Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions backend/apps/chat/models/chat_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional
from typing import List, Optional, Union

from fastapi import Body
from pydantic import BaseModel
Expand All @@ -9,13 +9,14 @@
from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import SQLModel, Field

from apps.db.constant import DB
from apps.template.filter.generator import get_permissions_template
from apps.template.generate_analysis.generator import get_analysis_template
from apps.template.generate_chart.generator import get_chart_template
from apps.template.generate_dynamic.generator import get_dynamic_template
from apps.template.generate_guess_question.generator import get_guess_question_template
from apps.template.generate_predict.generator import get_predict_template
from apps.template.generate_sql.generator import get_sql_template
from apps.template.generate_sql.generator import get_sql_template, get_sql_example_template
from apps.template.select_datasource.generator import get_datasource_template


Expand Down Expand Up @@ -182,10 +183,27 @@ class AiModelQuestion(BaseModel):
custom_prompt: str = ""
error_msg: str = ""

def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True):
    """Build the system prompt used for SQL generation.

    Args:
        db_type: Datasource type, either a ``DB`` member or its string type.
            It is handed to ``get_sql_example_template`` to pick the
            dialect-specific rules and examples.
        enable_query_limit: When true, the ``query_limit`` section of the base
            SQL template and the ``*_with_limit`` example answers are used.
            When false, the limit section is an empty string and the plain
            example answers are used.

    Returns:
        The ``system`` entry of the base SQL template, formatted with this
        question's fields and the dialect-specific rule/example snippets.
    """
    # Fetch the base template once; it supplies both the optional limit
    # section and the system prompt itself.
    sql_template = get_sql_template()
    dialect = get_sql_example_template(db_type)

    base_sql_rules = dialect['quot_rule'] + dialect['limit_rule'] + dialect['other_rule']
    query_limit = sql_template['query_limit'] if enable_query_limit else ''

    # Example answers exist in two flavours; only the selected one is looked up.
    answer_suffix = '_with_limit' if enable_query_limit else ''

    return sql_template['system'].format(
        engine=self.engine,
        schema=self.db_schema,
        question=self.question,
        lang=self.lang,
        terminologies=self.terminologies,
        data_training=self.data_training,
        custom_prompt=self.custom_prompt,
        base_sql_rules=base_sql_rules,
        query_limit=query_limit,
        basic_sql_examples=dialect['basic_example'],
        example_engine=dialect['example_engine'],
        example_answer_1=dialect['example_answer_1' + answer_suffix],
        example_answer_2=dialect['example_answer_2' + answer_suffix],
        example_answer_3=dialect['example_answer_3' + answer_suffix],
    )

def sql_user_question(self, current_time: str):
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
Expand Down
3 changes: 2 additions & 1 deletion backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def init_messages(self):

self.sql_message = []
# add sys prompt
self.sql_message.append(SystemMessage(content=self.chat_question.sql_sys_question()))
self.sql_message.append(SystemMessage(
content=self.chat_question.sql_sys_question(self.ds.type, settings.GENERATE_SQL_QUERY_LIMIT_ENABLED)))
if last_sql_messages is not None and len(last_sql_messages) > 0:
# limit count
for last_sql_message in last_sql_messages[count_limit:]:
Expand Down
36 changes: 20 additions & 16 deletions backend/apps/db/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,33 @@ def __init__(self, type_name):


class DB(Enum):
    """Supported datasource types.

    Each member's value is the tuple
    ``(type, db_name, prefix, suffix, connect_type, template_name)``, which
    ``__init__`` unpacks into attributes of the same names.
    """

    excel = ('excel', 'Excel/CSV', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL')
    redshift = ('redshift', 'AWS Redshift', '"', '"', ConnectType.py_driver, 'AWS_Redshift')
    ck = ('ck', 'ClickHouse', '"', '"', ConnectType.sqlalchemy, 'ClickHouse')
    dm = ('dm', '达梦', '"', '"', ConnectType.py_driver, 'DM')
    doris = ('doris', 'Apache Doris', '`', '`', ConnectType.py_driver, 'Doris')
    es = ('es', 'Elasticsearch', '"', '"', ConnectType.py_driver, 'Elasticsearch')
    kingbase = ('kingbase', 'Kingbase', '"', '"', ConnectType.py_driver, 'Kingbase')
    sqlServer = ('sqlServer', 'Microsoft SQL Server', '[', ']', ConnectType.sqlalchemy, 'Microsoft_SQL_Server')
    mysql = ('mysql', 'MySQL', '`', '`', ConnectType.sqlalchemy, 'MySQL')
    oracle = ('oracle', 'Oracle', '"', '"', ConnectType.sqlalchemy, 'Oracle')
    pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL')
    starrocks = ('starrocks', 'StarRocks', '`', '`', ConnectType.py_driver, 'StarRocks')

    def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType, template_name: str):
        # Short type key that get_db() matches against.
        self.type = type
        # Display name of the database product.
        self.db_name = db_name
        # NOTE(review): prefix/suffix look like the opening/closing identifier
        # quote characters for the dialect - confirm against callers.
        self.prefix = prefix
        self.suffix = suffix
        self.connect_type = connect_type
        # Base name (without extension) of this database's SQL example
        # template file.
        self.template_name = template_name

    @classmethod
    def get_db(cls, type, default_if_none=False):
        """Return the member whose ``type`` attribute equals ``type``.

        Args:
            type: Type key to look up (e.g. ``'mysql'``).
            default_if_none: When true, an unknown key yields ``DB.pg``
                instead of raising.

        Raises:
            ValueError: If no member matches and ``default_if_none`` is false.
        """
        for db in cls:
            if db.type == type:
                return db
        if default_if_none:
            return cls.pg
        raise ValueError(f"Invalid db type: {type}")
10 changes: 9 additions & 1 deletion backend/apps/template/generate_sql/generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from apps.template.template import get_base_template
from typing import Union

from apps.db.constant import DB
from apps.template.template import get_base_template, get_sql_template as get_base_sql_template


def get_sql_template():
    """Return the ``sql`` section found under ``template`` in the base template."""
    return get_base_template()['template']['sql']


def get_sql_example_template(db_type: Union[str, DB]):
    """Return the ``template`` section of the SQL template loaded for ``db_type``.

    ``db_type`` is forwarded unchanged to the template module's loader.
    """
    return get_base_sql_template(db_type)['template']
65 changes: 57 additions & 8 deletions backend/apps/template/template.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,64 @@
import yaml
from pathlib import Path
from functools import cache
from typing import Union

base_template = None
from apps.db.constant import DB

# Base path configuration
PROJECT_ROOT = Path(__file__).parent.parent.parent
TEMPLATES_DIR = PROJECT_ROOT / 'templates'
BASE_TEMPLATE_PATH = TEMPLATES_DIR / 'template.yaml'
SQL_TEMPLATES_DIR = TEMPLATES_DIR / 'sql_examples'

def load():
    """Parse ``./template.yaml`` and store the result in the module-level ``base_template``.

    The path is relative, so it is resolved against the current working
    directory of the process.
    """
    global base_template
    with open('./template.yaml', 'r', encoding='utf-8') as template_file:
        base_template = yaml.load(template_file, Loader=yaml.SafeLoader)

@cache
def _load_template_file(file_path: Path):
    """Load and parse a YAML template file, caching the result per path.

    Because of ``functools.cache``, each distinct ``file_path`` is read at
    most once until the cache is cleared, and every caller receives the same
    parsed object, so callers should treat it as read-only.

    Args:
        file_path: Location of the YAML file to read (UTF-8).

    Returns:
        The object produced by ``yaml.safe_load`` for the file contents.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the file is not valid YAML.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)
    except FileNotFoundError as exc:
        # Re-raise with the resolved path in the message, keeping the
        # original error as the explicit cause.
        raise FileNotFoundError(f"Template file not found at {file_path}") from exc
    except yaml.YAMLError as exc:
        raise ValueError(f"Error parsing YAML file {file_path}: {exc}") from exc


def get_base_template():
    """Return the parsed base template file.

    The file at ``BASE_TEMPLATE_PATH`` is read through the cached loader, so
    repeated calls do not re-read it from disk.
    """
    return _load_template_file(BASE_TEMPLATE_PATH)


def get_sql_template(db_type: Union[str, DB]):
    """Load the SQL example template file for a database type.

    Args:
        db_type: A ``DB`` member, or a string type key. Unknown strings and
            values of any other type fall back to ``DB.pg``.

    Returns:
        The parsed contents of ``<template_name>.yaml`` under
        ``SQL_TEMPLATES_DIR`` for the resolved database.
    """
    if isinstance(db_type, DB):
        resolved = db_type
    elif isinstance(db_type, str):
        # Look the string up in the enum; an unknown key yields DB.pg.
        resolved = DB.get_db(db_type, default_if_none=True)
    else:
        resolved = DB.pg

    # The member's template_name doubles as the file name.
    return _load_template_file(SQL_TEMPLATES_DIR / f"{resolved.template_name}.yaml")


def get_all_sql_templates():
    """Return a dict mapping each database's ``type`` to its SQL template.

    Databases whose template file is missing are left out of the result.
    """
    loaded = {}
    for db in DB:
        try:
            template = get_sql_template(db)
        except FileNotFoundError:
            # No template file for this database: skip it.
            continue
        loaded[db.type] = template
    return loaded


def reload_all_templates():
    """Drop every cached template so the next access re-reads it from disk."""
    loader = _load_template_file
    loader.cache_clear()


2 changes: 2 additions & 0 deletions backend/common/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
EMBEDDING_TERMINOLOGY_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT
EMBEDDING_DATA_TRAINING_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT

GENERATE_SQL_QUERY_LIMIT_ENABLED: bool = True

PARSE_REASONING_BLOCK_ENABLED: bool = True
DEFAULT_REASONING_CONTENT_START: str = '<think>'
DEFAULT_REASONING_CONTENT_END: str = '</think>'
Expand Down
86 changes: 86 additions & 0 deletions backend/templates/sql_examples/AWS_Redshift.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
template:
quot_rule: |
<rule>
必须对数据库名、表名、字段名、别名外层加双引号(")。
<note>
1. 点号(.)不能包含在引号内,必须写成 "schema"."table"
2. 即使标识符不含特殊字符或非关键字,也需强制加双引号
3. Redshift 默认将未加引号的标识符转为小写
</note>
</rule>

limit_rule: |
<rule>
使用 LIMIT 或 FETCH FIRST 限制行数(Redshift 兼容 PostgreSQL)
<note>
1. 标准写法:LIMIT 100
2. 可选写法:FETCH FIRST 100 ROWS ONLY
</note>
</rule>

other_rule: |
<rule>必须为每个表生成别名(不加AS)</rule>
<rule>禁止使用星号(*),必须明确字段名</rule>
<rule>中文/特殊字符字段需保留原名并添加英文别名</rule>
<rule>函数字段必须加别名</rule>
<rule>百分比字段保留两位小数并以%结尾(使用ROUND+CONCAT)</rule>
<rule>避免与Redshift关键字冲突(如USER/GROUP/ORDER等)</rule>

basic_example: |
<basic-examples>
<intro>
📌 以下示例严格遵循<Rules>中的 AWS Redshift 规范,展示符合要求的 SQL 写法与典型错误案例。
⚠️ 注意:示例中的表名、字段名均为演示虚构,实际使用时需替换为用户提供的真实标识符。
🔍 重点观察:
1. 双引号包裹所有数据库对象的规范用法
2. 中英别名/百分比/函数等特殊字段的处理
3. 关键字冲突的规避方式
</intro>
<example>
<input>查询 TEST.SALES 表的前100条订单(含百分比计算)</input>
<output-bad>
SELECT * FROM TEST.SALES LIMIT 100 -- 错误:未加引号、使用星号
SELECT "订单ID", "金额" FROM "TEST"."SALES" "t1" FETCH FIRST 100 ROWS ONLY -- 错误:缺少英文别名
SELECT COUNT("订单ID") FROM "TEST"."SALES" "t1" -- 错误:函数未加别名
</output-bad>
<output-good>
SELECT
"t1"."订单ID" AS "order_id",
"t1"."金额" AS "amount",
COUNT("t1"."订单ID") AS "total_orders",
CONCAT(ROUND("t1"."折扣率" * 100, 2), '%') AS "discount_percent"
FROM "TEST"."SALES" "t1"
LIMIT 100
</output-good>
</example>

<example>
<input>统计用户表 PUBLIC.USERS(含关键字字段user)的活跃占比</input>
<output-bad>
SELECT user, status FROM PUBLIC.USERS -- 错误:未处理关键字和引号
SELECT "user", ROUND(active_ratio) FROM "PUBLIC"."USERS" -- 错误:百分比格式错误
</output-bad>
<output-good>
SELECT
"u"."user" AS "user_account",
CONCAT(ROUND("u"."active_ratio" * 100, 2), '%') AS "active_percent"
FROM "PUBLIC"."USERS" "u"
WHERE "u"."status" = 1
FETCH FIRST 1000 ROWS ONLY
</output-good>
</example>
</basic-examples>

example_engine: AWS Redshift 1.0
example_answer_1: |
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\"","tables":["sample_country_gdp"],"chart-type":"line"}
example_answer_1_with_limit: |
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\" LIMIT 1000 OFFSET 0","tables":["sample_country_gdp"],"chart-type":"line"}
example_answer_2: |
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC","tables":["sample_country_gdp"],"chart-type":"pie"}
example_answer_2_with_limit: |
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC LIMIT 1000 OFFSET 0","tables":["sample_country_gdp"],"chart-type":"pie"}
example_answer_3: |
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country\" = '中国'","tables":["sample_country_gdp"],"chart-type":"table"}
example_answer_3_with_limit: |
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country\" = '中国' LIMIT 1000 OFFSET 0","tables":["sample_country_gdp"],"chart-type":"table"}
90 changes: 90 additions & 0 deletions backend/templates/sql_examples/ClickHouse.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
template:
quot_rule: |
<rule>
必须对数据库名、表名、字段名、别名外层加双引号(")。
<note>
1. 点号(.)不能包含在引号内,必须写成 "database"."table"
2. ClickHouse 严格区分大小写,必须通过引号保留原始大小写
3. 嵌套字段使用点号连接:`"json_column.field"`
</note>
</rule>

limit_rule: |
<rule>
行数限制使用标准SQL语法:
<note>
1. 标准写法:LIMIT [count]
2. 分页写法:LIMIT [count] OFFSET [start]
3. 禁止使用原生 `topk()` 等函数替代
</note>
</rule>

other_rule: |
<rule>必须为每个表生成简短别名(如t1/t2)</rule>
<rule>禁止使用星号(*),必须明确字段名</rule>
<rule>JSON字段需用点号语法访问:`"column.field"`</rule>
<rule>函数字段必须加别名</rule>
<rule>百分比显示为:`ROUND(x*100,2) || '%'`</rule>
<rule>避免与ClickHouse关键字冲突(如`timestamp`/`default`)</rule>

basic_example: |
<basic-examples>
<intro>
📌 以下示例严格遵循<Rules>中的 ClickHouse 规范,展示符合要求的 SQL 写法与典型错误案例。
⚠️ 注意:示例中的表名、字段名均为演示虚构,实际使用时需替换为用户提供的真实标识符。
🔍 重点观察:
1. 双引号包裹所有数据库对象的规范用法
2. 中英别名/百分比/函数等特殊字段的处理
3. 关键字冲突的规避方式
</intro>
<example>
<input>查询 events 表的前100条错误日志(含JSON字段)</input>
<output-bad>
SELECT * FROM default.events LIMIT 100 -- 错误1:使用星号
SELECT message FROM "default"."events" WHERE level = 'error' -- 错误2:未处理JSON字段
SELECT "message", "extra.error_code" FROM events LIMIT 100 -- 错误3:表名未加引号
</output-bad>
<output-good>
SELECT
"e"."message" AS "log_content",
"e"."extra"."error_code" AS "error_id",
toDateTime("e"."timestamp") AS "log_time"
FROM "default"."events" "e"
WHERE "e"."level" = 'error'
LIMIT 100
</output-good>
</example>

<example>
<input>统计各地区的错误率Top 5(含百分比)</input>
<output-bad>
SELECT region, COUNT(*) FROM events GROUP BY region -- 错误1:未加引号、未使用表别名、函数未加别名
SELECT "region", MAX("count") FROM "events" GROUP BY 1 -- 错误2:使用序号分组
</output-bad>
<output-good>
SELECT
"e"."region" AS "area",
COUNT(*) AS "total",
COUNTIf("e"."level" = 'error') AS "error_count",
ROUND(error_count * 100.0 / total, 2) || '%' AS "error_rate"
FROM "default"."events" "e"
GROUP BY "e"."region"
ORDER BY "error_rate" DESC
LIMIT 5
</output-good>
</example>
</basic-examples>

example_engine: ClickHouse 23.3
example_answer_1: |
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\"","tables":["sample_country_gdp"],"chart-type":"line"}
example_answer_1_with_limit: |
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\" LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"line"}
example_answer_2: |
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC","tables":["sample_country_gdp"],"chart-type":"pie"}
example_answer_2_with_limit: |
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"pie"}
example_answer_3: |
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country\" = '中国'","tables":["sample_country_gdp"],"chart-type":"table"}
example_answer_3_with_limit: |
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country\" = '中国' LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"table"}
Loading