diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index 7713270f..5dbbae3c 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -210,9 +210,9 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T example_answer_2=_example_answer_2, example_answer_3=_example_answer_3) - def sql_user_question(self, current_time: str): + def sql_user_question(self, current_time: str, change_title: bool): return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question, - rule=self.rule, current_time=current_time, error_msg=self.error_msg) + rule=self.rule, current_time=current_time, error_msg=self.error_msg,change_title = change_title) def chart_sys_question(self): return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang) diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 39ab5423..c12aaab3 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -524,7 +524,7 @@ def select_datasource(self, _session: Session): def generate_sql(self, _session: Session): # append current question self.sql_message.append(HumanMessage( - self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')))) + self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),change_title = self.change_title))) self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=_session, ai_modal_id=self.chat_question.ai_modal_id, @@ -756,6 +756,26 @@ def get_chart_type_from_sql_answer(res: str) -> Optional[str]: return chart_type + @staticmethod + def get_brief_from_sql_answer(res: str) -> Optional[str]: + json_str = extract_nested_json(res) + if json_str is None: + return None + + brief: Optional[str] + data: dict + try: + data = orjson.loads(json_str) + + if data['success']: + brief = data['brief'] + else: + return None + except Exception: + return None + + return brief + def check_save_sql(self, session: Session, res: str) -> str: sql, *_ = self.check_sql(res=res) save_sql(session=session, sql=sql, record_id=self.record.id) @@ -925,17 +945,6 @@ def run_task(self, in_chat: bool = True, stream: bool = True, if not stream: json_result['record_id'] = self.get_record().id - # return title - if self.change_title: - if self.chat_question.question and self.chat_question.question.strip() != '': - brief = rename_chat(session=_session, - rename_object=RenameChat(id=self.get_record().chat_id, - brief=self.chat_question.question.strip()[:20])) - if in_chat: - yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n' - if not stream: - json_result['title'] = brief - # select datasource if datasource is none if not self.ds: ds_res = self.select_datasource(_session) @@ -981,6 +990,19 @@ def run_task(self, in_chat: bool = True, stream: bool = True, chart_type = self.get_chart_type_from_sql_answer(full_sql_text) + # return title + if self.change_title: + llm_brief = self.get_brief_from_sql_answer(full_sql_text) + if (llm_brief and llm_brief != '') or (self.chat_question.question and self.chat_question.question.strip() != ''): + save_brief = llm_brief if (llm_brief and llm_brief != '') else self.chat_question.question.strip()[:20] + brief = rename_chat(session=_session, + rename_object=RenameChat(id=self.get_record().chat_id, + brief=save_brief)) + if in_chat: + yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n' + if not stream: + json_result['title'] = brief + use_dynamic_ds: bool = self.current_assistant and self.current_assistant.type in dynamic_ds_types is_page_embedded: bool = self.current_assistant and self.current_assistant.type == 4 dynamic_sql_result = None diff --git a/backend/templates/template.yaml b/backend/templates/template.yaml index def665e6..8f51253b 100644 --- a/backend/templates/template.yaml +++ b/backend/templates/template.yaml @@ -14,7 +14,8 @@ template: 4. 应用其他规则(引号、别名等) 5. 强制检查:检查语法是否正确? 6. 确定图表类型 - 7. 返回JSON结果 + 7. 确定对话标题 + 8. 返回JSON结果 query_limit: | @@ -41,7 +42,7 @@ template: system: | 你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。 - 你当前的任务是根据给定的表结构和用户问题生成SQL语句、可能适合展示的图表类型以及该SQL中所用到的表名。 + 你当前的任务是根据给定的表结构和用户问题生成SQL语句、对话标题、可能适合展示的图表类型以及该SQL中所用到的表名。 我们会在块内提供给你信息,帮助你生成SQL: 内有等信息; 其中,:提供数据库引擎及版本信息; @@ -72,7 +73,7 @@ template: 请使用JSON格式返回你的回答: - 若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}} + 若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table","brief":"如何需要生成对话标题,在这里填写你生成的对话标题,否则不需要这个字段"}} 若不能生成,则返回格式如:{{"success":false,"message":"说明无法生成SQL的原因"}} @@ -112,6 +113,9 @@ template: 我们目前的情况适用于单指标、多分类的场景(展示table除外) + + 是否生成对话标题在内,如果为True需要生成,否则不需要生成,生成的对话标题要求在20字以内 + {process_check} @@ -251,6 +255,9 @@ template: {question} + + {change_title} + chart: system: |