From 3e4873c14172a866956b28f6bb56bd7d1f2a87a2 Mon Sep 17 00:00:00 2001 From: wangjiahao <1522128093@qq.com> Date: Tue, 2 Dec 2025 15:38:11 +0800 Subject: [PATCH] fix: Fixed the issue where the conversation title was not adjusted when the first sentence failed --- backend/apps/chat/curd/chat.py | 12 ++++++++++++ backend/apps/chat/models/chat_model.py | 4 +++- backend/apps/chat/task/llm.py | 20 +++++++++++--------- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index 0872c6f4..f6dafb44 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -28,6 +28,10 @@ def get_chat_record_by_id(session: SessionDep, record_id: int): engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by) return record +def get_chat(session: SessionDep, chat_id: int) -> Chat: + statement = select(Chat).where(Chat.id == chat_id) + chat = session.exec(statement).scalars().first() + return chat def list_chats(session: SessionDep, current_user: CurrentUser) -> List[Chat]: oid = current_user.oid if current_user.oid is not None else 1 @@ -57,6 +61,7 @@ def rename_chat(session: SessionDep, rename_object: RenameChat) -> str: raise Exception(f"Chat with id {rename_object.id} not found") chat.brief = rename_object.brief.strip()[:20] + chat.brief_generate = rename_object.brief_generate session.add(chat) session.flush() session.refresh(chat) @@ -340,6 +345,13 @@ def format_record(record: ChatRecordResult): return _dict +def get_chat_brief_generate(session: SessionDep, chat_id: int): + chat = get_chat(session=session,chat_id=chat_id) + if chat is not None and chat.brief_generate is not None: + return chat.brief_generate + else: + return False + def list_generate_sql_logs(session: SessionDep, chart_id: int) -> List[ChatLog]: stmt = select(ChatLog).where( diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index 966fc7ae..0786c2f6 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -78,7 +78,8 @@ class Chat(SQLModel, table=True): datasource: int = Field(sa_column=Column(BigInteger, nullable=True)) engine_type: str = Field(max_length=64) origin: Optional[int] = Field( - sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant + sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant + brief_generate: bool = Field(default=False) class ChatRecord(SQLModel, table=True): @@ -149,6 +150,7 @@ class CreateChat(BaseModel): class RenameChat(BaseModel): id: int = None brief: str = '' + brief_generate: bool = True class ChatInfo(BaseModel): diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 39622eb3..47ab30a9 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -29,7 +29,7 @@ save_select_datasource_answer, save_recommend_question_answer, \ get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \ get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \ - get_last_execute_sql_error, format_json_data, format_chart_fields + get_last_execute_sql_error, format_json_data, format_chart_fields, get_chat_brief_generate from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \ ChatFinishStep, AxisObj from apps.data_training.curd.data_training import get_training_template @@ -117,7 +117,7 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id) self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id) - self.change_title = len(self.generate_sql_logs) == 0 + self.change_title = not get_chat_brief_generate(session=session, chat_id=chat_id) chat_question.lang = get_lang_name(current_user.language) @@ -528,7 +528,8 @@ def select_datasource(self, _session: Session): def generate_sql(self, _session: Session): # append current question self.sql_message.append(HumanMessage( - self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),change_title = self.change_title))) + self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + change_title=self.change_title))) self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=_session, ai_modal_id=self.chat_question.ai_modal_id, @@ -997,11 +998,13 @@ def run_task(self, in_chat: bool = True, stream: bool = True, # return title if self.change_title: llm_brief = self.get_brief_from_sql_answer(full_sql_text) - if (llm_brief and llm_brief != '') or (self.chat_question.question and self.chat_question.question.strip() != ''): - save_brief = llm_brief if (llm_brief and llm_brief != '') else self.chat_question.question.strip()[:20] + llm_brief_generated = bool(llm_brief) + if llm_brief_generated or (self.chat_question.question and self.chat_question.question.strip() != ''): + save_brief = llm_brief if (llm_brief and llm_brief != '') else self.chat_question.question.strip()[ + :20] brief = rename_chat(session=_session, rename_object=RenameChat(id=self.get_record().chat_id, - brief=save_brief)) + brief=save_brief, brief_generate=llm_brief_generated)) if in_chat: yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n' if not stream: @@ -1084,7 +1087,8 @@ def run_task(self, in_chat: bool = True, stream: bool = True, for field in result.get('fields'): _column_list.append(AxisObj(name=field, value=field)) - md_data, _fields_list = DataFormat.convert_object_array_for_pandas(_column_list, result.get('data')) + md_data, _fields_list = DataFormat.convert_object_array_for_pandas(_column_list, + result.get('data')) # data, _fields_list, col_formats = self.format_pd_data(_column_list, result.get('data')) @@ -1203,8 +1207,6 @@ def run_task(self, in_chat: bool = True, stream: bool = True, self.finish(_session) session_maker.remove() - - def run_recommend_questions_task_async(self): self.future = executor.submit(self.run_recommend_questions_task_cache)