Skip to content

Commit ee16af8

Browse files
committed
feat: improve chart field name matching with original table structure
1 parent 8c309c0 commit ee16af8

File tree

5 files changed

+45
-14
lines changed

5 files changed

+45
-14
lines changed

backend/apps/chat/models/chat_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ class ChatInfo(BaseModel):
175175
ds_type: str = ''
176176
datasource_name: str = ''
177177
datasource_exists: bool = True
178-
recommended_question: Optional[str] = None
179-
recommended_generate: Optional[bool] = False
178+
recommended_question: Optional[str] = None
179+
recommended_generate: Optional[bool] = False
180180
records: List[ChatRecord | dict] = []
181181

182182

@@ -237,9 +237,9 @@ def sql_user_question(self, current_time: str, change_title: bool):
237237
def chart_sys_question(self):
238238
return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)
239239

240-
def chart_user_question(self, chart_type: Optional[str] = None):
240+
def chart_user_question(self, chart_type: Optional[str] = '', schema: Optional[str] = ''):
241241
return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule,
242-
chart_type=chart_type)
242+
chart_type=chart_type, schema=schema)
243243

244244
def analysis_sys_question(self):
245245
return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies,

backend/apps/chat/task/llm.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
111111
_ds = session.get(CoreDatasource, chat_question.datasource_id)
112112
if _ds:
113113
if _ds.oid != current_user.oid:
114-
raise SingleMessageError(f"Datasource with id {chat_question.datasource_id} does not belong to current workspace")
114+
raise SingleMessageError(
115+
f"Datasource with id {chat_question.datasource_id} does not belong to current workspace")
115116
chat.datasource = _ds.id
116117
chat.engine_type = _ds.type_name
117118
# save chat
@@ -410,7 +411,8 @@ def generate_recommend_questions_task(self, _session: Session):
410411
reasoning_content=full_thinking_text,
411412
token_usage=token_usage)
412413
self.record = save_recommend_question_answer(session=_session, record_id=self.record.id,
413-
answer={'content': full_guess_text}, articles_number=self.articles_number)
414+
answer={'content': full_guess_text},
415+
articles_number=self.articles_number)
414416

415417
yield {'recommended_question': self.record.recommended_question}
416418

@@ -716,9 +718,9 @@ def generate_assistant_filter(self, _session: Session, sql, tables: List):
716718
return None
717719
return self.build_table_filter(session=_session, sql=sql, filters=filters)
718720

719-
def generate_chart(self, _session: Session, chart_type: Optional[str] = ''):
721+
def generate_chart(self, _session: Session, chart_type: Optional[str] = '', schema: Optional[str] = ''):
720722
# append current question
721-
self.chart_message.append(HumanMessage(self.chat_question.chart_user_question(chart_type)))
723+
self.chart_message.append(HumanMessage(self.chat_question.chart_user_question(chart_type, schema)))
722724

723725
self.current_logs[OperationEnum.GENERATE_CHART] = start_log(session=_session,
724726
ai_modal_id=self.chat_question.ai_modal_id,
@@ -1079,9 +1081,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
10791081
sqlbot_temp_sql_text = None
10801082
assistant_dynamic_sql = None
10811083
# row permission
1084+
sql, tables = self.check_sql(res=full_sql_text)
10821085
if ((not self.current_assistant or is_page_embedded) and is_normal_user(
10831086
self.current_user)) or use_dynamic_ds:
1084-
sql, tables = self.check_sql(res=full_sql_text)
10851087
sql_result = None
10861088

10871089
if use_dynamic_ds:
@@ -1167,7 +1169,16 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
11671169
return
11681170

11691171
# generate chart
1170-
chart_res = self.generate_chart(_session, chart_type)
1172+
used_tables_schema = self.out_ds_instance.get_db_schema(
1173+
self.ds.id, self.chat_question.question, embedding=False,
1174+
table_list=tables) if self.out_ds_instance else get_table_schema(
1175+
session=_session,
1176+
current_user=self.current_user,
1177+
ds=self.ds,
1178+
question=self.chat_question.question,
1179+
embedding=False, table_list=tables)
1180+
SQLBotLogUtil.info('used_tables_schema: \n' + used_tables_schema)
1181+
chart_res = self.generate_chart(_session, chart_type, used_tables_schema)
11711182
full_chart_text = ''
11721183
for chunk in chart_res:
11731184
full_chart_text += chunk.get('content')
@@ -1482,7 +1493,7 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict):
14821493
y = None
14831494
series = None
14841495
multi_quota_fields = []
1485-
multi_quota_name =None
1496+
multi_quota_name = None
14861497

14871498
if chart.get('axis'):
14881499
axis_data = chart.get('axis')

backend/apps/datasource/crud/datasource.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
425425

426426

427427
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
428-
embedding: bool = True) -> str:
428+
embedding: bool = True, table_list: list[str] = None) -> str:
429429
schema_str = ""
430430
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
431431
if len(table_objs) == 0:
@@ -435,6 +435,10 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
435435
tables = []
436436
all_tables = [] # temp save all tables
437437
for obj in table_objs:
438+
# 如果传入了table_list,则只处理在列表中的表
439+
if table_list is not None and obj.table.table_name not in table_list:
440+
continue
441+
438442
schema_table = ''
439443
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
440444
table_comment = ''
@@ -462,6 +466,10 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
462466
tables.append(t_obj)
463467
all_tables.append(t_obj)
464468

469+
# 如果没有符合过滤条件的表,直接返回
470+
if not tables:
471+
return schema_str
472+
465473
# do table embedding
466474
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
467475
tables = calc_table_embedding(tables, question)

backend/apps/system/crud/assistant.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,19 @@ def get_simple_ds_list(self):
172172
else:
173173
raise Exception("Datasource list is not found.")
174174

175-
def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> str:
175+
def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
176+
table_list: list[str] = None) -> str:
176177
ds = self.get_ds(ds_id)
177178
schema_str = ""
178179
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
179180
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
180181
tables = []
181182
i = 0
182183
for table in ds.tables:
184+
# 如果传入了 table_list,则只处理在列表中的表
185+
if table_list is not None and table.name not in table_list:
186+
continue
187+
183188
i += 1
184189
schema_table = ''
185190
schema_table += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {table.name}"

backend/templates/template.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,11 @@ template:
304304
<Instruction>
305305
你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL,查询数据并进行图表展示。
306306
你当前的任务是根据给定SQL语句和用户问题,生成数据可视化图表的配置项。
307-
用户的提问在<user-question>内,<sql>内是给定需要参考的SQL,<chart-type>内是推荐你生成的图表类型
307+
用户会提供给你如下信息,帮助你生成配置项:
308+
<user-question>:用户的提问
309+
<sql>:需要参考的SQL
310+
<m-schema>:以 M-Schema 格式提供 SQL 内用到表的数据库表结构信息,你可以参考字段名与字段备注来生成图表使用到的字段名
311+
<chart-type>:推荐你生成的图表类型
308312
</Instruction>
309313
310314
你必须遵守以下规则:
@@ -455,6 +459,9 @@ template:
455459
<sql>
456460
{sql}
457461
</sql>
462+
<m-schema>
463+
{schema}
464+
</m-schema>
458465
<chart-type>
459466
{chart_type}
460467
</chart-type>

0 commit comments

Comments
 (0)