From 45249487eb6ea6995a0ef0158643d1689e98e1e1 Mon Sep 17 00:00:00 2001 From: ulleo Date: Mon, 13 Oct 2025 15:07:29 +0800 Subject: [PATCH] fix: improve sqlalchemy session in threads #294 --- backend/apps/chat/api/chat.py | 11 +- backend/apps/chat/task/llm.py | 282 +++++++++++++++++----------------- backend/apps/mcp/mcp.py | 8 +- 3 files changed, 154 insertions(+), 147 deletions(-) diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py index 8b6f240f..5f0199b8 100644 --- a/backend/apps/chat/api/chat.py +++ b/backend/apps/chat/api/chat.py @@ -118,7 +118,7 @@ def _return_empty(): request_question = ChatQuestion(chat_id=record.chat_id, question=record.question if record.question else '') - llm_service = await LLMService.create(current_user, request_question, current_assistant, True) + llm_service = await LLMService.create(session, current_user, request_question, current_assistant, True) llm_service.set_record(record) llm_service.run_recommend_questions_task_async() except Exception as e: @@ -147,8 +147,9 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que """ try: - llm_service = await LLMService.create(current_user, request_question, current_assistant, embedding=True) - llm_service.init_record() + llm_service = await LLMService.create(session, current_user, request_question, current_assistant, + embedding=True) + llm_service.init_record(session=session) llm_service.run_task_async() except Exception as e: traceback.print_exc() @@ -189,8 +190,8 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch request_question = ChatQuestion(chat_id=record.chat_id, question=record.question) - llm_service = await LLMService.create(current_user, request_question, current_assistant) - llm_service.run_analysis_or_predict_task_async(action_type, record) + llm_service = await LLMService.create(session, current_user, request_question, current_assistant) + llm_service.run_analysis_or_predict_task_async(session, 
action_type, record) except Exception as e: traceback.print_exc() diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index f0439ca2..b9d0a3b3 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -16,7 +16,7 @@ from langchain_community.utilities import SQLDatabase from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk from sqlalchemy import and_, select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, scoped_session from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil @@ -56,8 +56,7 @@ dynamic_ds_types = [1, 3] dynamic_subsql_prefix = 'select * from sqlbot_dynamic_temp_table_' -session_maker = sessionmaker(bind=engine) -db_session = session_maker() +session_maker = scoped_session(sessionmaker(bind=engine)) class LLMService: @@ -69,7 +68,7 @@ class LLMService: sql_message: List[Union[BaseMessage, dict[str, Any]]] = [] chart_message: List[Union[BaseMessage, dict[str, Any]]] = [] - session: Session = db_session + # session: Session = db_session current_user: CurrentUser current_assistant: Optional[CurrentAssistant] = None out_ds_instance: Optional[AssistantOutDs] = None @@ -85,25 +84,20 @@ class LLMService: last_execute_sql_error: str = None - def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, + def __init__(self, session: Session, current_user: CurrentUser, chat_question: ChatQuestion, current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False, embedding: bool = False, config: LLMConfig = None): self.chunk_list = [] - # engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) - # session_maker = sessionmaker(bind=engine) - # self.session = session_maker() - self.session.exec = self.session.exec if 
hasattr(self.session, "exec") else self.session.execute + session.exec = session.exec if hasattr(session, "exec") else session.execute self.current_user = current_user self.current_assistant = current_assistant - # chat = self.session.query(Chat).filter(Chat.id == chat_question.chat_id).first() chat_id = chat_question.chat_id - chat: Chat | None = self.session.get(Chat, chat_id) + chat: Chat | None = session.get(Chat, chat_id) if not chat: raise SingleMessageError(f"Chat with id {chat_id} not found") ds: CoreDatasource | AssistantOutDsSchema | None = None if chat.datasource: # Get available datasource - # ds = self.session.query(CoreDatasource).filter(CoreDatasource.id == chat.datasource).first() if current_assistant and current_assistant.type in dynamic_ds_types: self.out_ds_instance = AssistantOutDsFactory.get_instance(current_assistant) ds = self.out_ds_instance.get_ds(chat.datasource) @@ -112,21 +106,22 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, chat_question.engine = ds.type + get_version(ds) chat_question.db_schema = self.out_ds_instance.get_db_schema(ds.id, chat_question.question) else: - ds = self.session.get(CoreDatasource, chat.datasource) + ds = session.get(CoreDatasource, chat.datasource) if not ds: raise SingleMessageError("No available datasource configuration found") chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds) - chat_question.db_schema = get_table_schema(session=self.session, current_user=current_user, ds=ds, + chat_question.db_schema = get_table_schema(session=session, current_user=current_user, ds=ds, question=chat_question.question, embedding=embedding) - self.generate_sql_logs = list_generate_sql_logs(session=self.session, chart_id=chat_id) - self.generate_chart_logs = list_generate_chart_logs(session=self.session, chart_id=chat_id) + self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id) + self.generate_chart_logs = 
list_generate_chart_logs(session=session, chart_id=chat_id) self.change_title = len(self.generate_sql_logs) == 0 chat_question.lang = get_lang_name(current_user.language) - self.ds = (ds if isinstance(ds, AssistantOutDsSchema) else CoreDatasource(**ds.model_dump())) if ds else None + self.ds = ( + ds if isinstance(ds, AssistantOutDsSchema) else CoreDatasource(**ds.model_dump())) if ds else None self.chat_question = chat_question self.config = config if no_reasoning: @@ -144,7 +139,7 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, self.llm = llm_instance.llm # get last_execute_sql_error - last_execute_sql_error = get_last_execute_sql_error(self.session, self.chat_question.chat_id) + last_execute_sql_error = get_last_execute_sql_error(session, self.chat_question.chat_id) if last_execute_sql_error: self.chat_question.error_msg = f''' {last_execute_sql_error} @@ -207,8 +202,8 @@ def init_messages(self): _msg = AIMessage(content=last_chart_message.get('content')) self.chart_message.append(_msg) - def init_record(self) -> ChatRecord: - self.record = save_question(session=self.session, current_user=self.current_user, question=self.chat_question) + def init_record(self, session: Session) -> ChatRecord: + self.record = save_question(session=session, current_user=self.current_user, question=self.chat_question) return self.record def get_record(self): @@ -217,8 +212,8 @@ def get_record(self): def set_record(self, record: ChatRecord): self.record = record - def get_fields_from_chart(self): - chart_info = get_chart_config(self.session, self.record.id) + def get_fields_from_chart(self, _session: Session): + chart_info = get_chart_config(_session, self.record.id) fields = [] if chart_info.get('columns') and len(chart_info.get('columns')) > 0: for column in chart_info.get('columns'): @@ -236,24 +231,24 @@ def get_fields_from_chart(self): fields.append(column_str) return fields - def generate_analysis(self): - fields = self.get_fields_from_chart() + 
def generate_analysis(self, _session: Session): + fields = self.get_fields_from_chart(_session) self.chat_question.fields = orjson.dumps(fields).decode() - data = get_chat_chart_data(self.session, self.record.id) + data = get_chat_chart_data(_session, self.record.id) self.chat_question.data = orjson.dumps(data.get('data')).decode() analysis_msg: List[Union[BaseMessage, dict[str, Any]]] = [] ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None - self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, + self.chat_question.terminologies = get_terminology_template(_session, self.chat_question.question, self.current_user.oid, ds_id) if SQLBotLicenseUtil.valid(): - self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.ANALYSIS, + self.chat_question.custom_prompt = find_custom_prompts(_session, CustomPromptTypeEnum.ANALYSIS, self.current_user.oid, ds_id) analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question())) analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question())) - self.current_logs[OperationEnum.ANALYSIS] = start_log(session=self.session, + self.current_logs[OperationEnum.ANALYSIS] = start_log(session=_session, ai_modal_id=self.chat_question.ai_modal_id, ai_modal_name=self.chat_question.ai_modal_name, operate=OperationEnum.ANALYSIS, @@ -276,7 +271,7 @@ def generate_analysis(self): analysis_msg.append(AIMessage(full_analysis_text)) - self.current_logs[OperationEnum.ANALYSIS] = end_log(session=self.session, + self.current_logs[OperationEnum.ANALYSIS] = end_log(session=_session, log=self.current_logs[ OperationEnum.ANALYSIS], full_message=[ @@ -285,25 +280,25 @@ def generate_analysis(self): for msg in analysis_msg], reasoning_content=full_thinking_text, token_usage=token_usage) - self.record = save_analysis_answer(session=self.session, record_id=self.record.id, + self.record = save_analysis_answer(session=_session, 
record_id=self.record.id, answer=orjson.dumps({'content': full_analysis_text}).decode()) - def generate_predict(self): - fields = self.get_fields_from_chart() + def generate_predict(self, _session: Session): + fields = self.get_fields_from_chart(_session) self.chat_question.fields = orjson.dumps(fields).decode() - data = get_chat_chart_data(self.session, self.record.id) + data = get_chat_chart_data(_session, self.record.id) self.chat_question.data = orjson.dumps(data.get('data')).decode() if SQLBotLicenseUtil.valid(): ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None - self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.PREDICT_DATA, + self.chat_question.custom_prompt = find_custom_prompts(_session, CustomPromptTypeEnum.PREDICT_DATA, self.current_user.oid, ds_id) predict_msg: List[Union[BaseMessage, dict[str, Any]]] = [] predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question())) predict_msg.append(HumanMessage(content=self.chat_question.predict_user_question())) - self.current_logs[OperationEnum.PREDICT_DATA] = start_log(session=self.session, + self.current_logs[OperationEnum.PREDICT_DATA] = start_log(session=_session, ai_modal_id=self.chat_question.ai_modal_id, ai_modal_name=self.chat_question.ai_modal_name, operate=OperationEnum.PREDICT_DATA, @@ -325,9 +320,9 @@ def generate_predict(self): yield chunk predict_msg.append(AIMessage(full_predict_text)) - self.record = save_predict_answer(session=self.session, record_id=self.record.id, + self.record = save_predict_answer(session=_session, record_id=self.record.id, answer=orjson.dumps({'content': full_predict_text}).decode()) - self.current_logs[OperationEnum.PREDICT_DATA] = end_log(session=self.session, + self.current_logs[OperationEnum.PREDICT_DATA] = end_log(session=_session, log=self.current_logs[ OperationEnum.PREDICT_DATA], full_message=[ @@ -337,13 +332,13 @@ def generate_predict(self): reasoning_content=full_thinking_text, 
token_usage=token_usage) - def generate_recommend_questions_task(self): + def generate_recommend_questions_task(self, _session: Session): # get schema if self.ds and not self.chat_question.db_schema: self.chat_question.db_schema = self.out_ds_instance.get_db_schema( self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema( - session=self.session, + session=_session, current_user=self.current_user, ds=self.ds, question=self.chat_question.question, embedding=False) @@ -351,11 +346,11 @@ def generate_recommend_questions_task(self): guess_msg: List[Union[BaseMessage, dict[str, Any]]] = [] guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question())) - old_questions = list(map(lambda q: q.strip(), get_old_questions(self.session, self.record.datasource))) + old_questions = list(map(lambda q: q.strip(), get_old_questions(_session, self.record.datasource))) guess_msg.append( HumanMessage(content=self.chat_question.guess_user_question(orjson.dumps(old_questions).decode()))) - self.current_logs[OperationEnum.GENERATE_RECOMMENDED_QUESTIONS] = start_log(session=self.session, + self.current_logs[OperationEnum.GENERATE_RECOMMENDED_QUESTIONS] = start_log(session=_session, ai_modal_id=self.chat_question.ai_modal_id, ai_modal_name=self.chat_question.ai_modal_name, operate=OperationEnum.GENERATE_RECOMMENDED_QUESTIONS, @@ -378,7 +373,7 @@ def generate_recommend_questions_task(self): guess_msg.append(AIMessage(full_guess_text)) - self.current_logs[OperationEnum.GENERATE_RECOMMENDED_QUESTIONS] = end_log(session=self.session, + self.current_logs[OperationEnum.GENERATE_RECOMMENDED_QUESTIONS] = end_log(session=_session, log=self.current_logs[ OperationEnum.GENERATE_RECOMMENDED_QUESTIONS], full_message=[ @@ -387,16 +382,16 @@ def generate_recommend_questions_task(self): for msg in guess_msg], reasoning_content=full_thinking_text, token_usage=token_usage) - self.record = save_recommend_question_answer(session=self.session, 
record_id=self.record.id, + self.record = save_recommend_question_answer(session=_session, record_id=self.record.id, answer={'content': full_guess_text}) yield {'recommended_question': self.record.recommended_question} - def select_datasource(self): + def select_datasource(self, _session: Session): datasource_msg: List[Union[BaseMessage, dict[str, Any]]] = [] datasource_msg.append(SystemMessage(self.chat_question.datasource_sys_question())) if self.current_assistant and self.current_assistant.type != 4: - _ds_list = get_assistant_ds(session=self.session, llm_service=self) + _ds_list = get_assistant_ds(session=_session, llm_service=self) else: stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where( and_(CoreDatasource.oid == self.current_user.oid)) @@ -406,10 +401,8 @@ def select_datasource(self): "name": ds.name, "description": ds.description } - for ds in self.session.exec(stmt) + for ds in _session.exec(stmt) ] - """ _ds_list = self.session.exec(select(CoreDatasource).options( - load_only(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description))).all() """ if not _ds_list: raise SingleMessageError('No available datasource configuration found') ignore_auto_select = _ds_list and len(_ds_list) == 1 @@ -420,7 +413,7 @@ def select_datasource(self): if not ignore_auto_select: if settings.TABLE_EMBEDDING_ENABLED and ( not self.current_assistant or (self.current_assistant and self.current_assistant.type != 1)): - ds = get_ds_embedding(self.session, self.current_user, _ds_list, self.out_ds_instance, + ds = get_ds_embedding(_session, self.current_user, _ds_list, self.out_ds_instance, self.chat_question.question, self.current_assistant) yield {'content': '{"id":' + str(ds.get('id')) + '}'} else: @@ -430,7 +423,7 @@ def select_datasource(self): datasource_msg.append( HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode()))) - self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = 
start_log(session=self.session, + self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=_session, ai_modal_id=self.chat_question.ai_modal_id, ai_modal_name=self.chat_question.ai_modal_name, operate=OperationEnum.CHOOSE_DATASOURCE, @@ -450,7 +443,7 @@ def select_datasource(self): yield chunk datasource_msg.append(AIMessage(full_text)) - self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session, + self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=_session, log=self.current_logs[ OperationEnum.CHOOSE_DATASOURCE], full_message=[ @@ -473,7 +466,7 @@ def select_datasource(self): if data.get('id') and data.get('id') != 0: _datasource = data['id'] - _chat = self.session.get(Chat, self.record.chat_id) + _chat = _session.get(Chat, self.record.chat_id) _chat.datasource = _datasource if self.current_assistant and self.current_assistant.type in dynamic_ds_types: _ds = self.out_ds_instance.get_ds(data['id']) @@ -484,28 +477,28 @@ def select_datasource(self): _engine_type = self.chat_question.engine _chat.engine_type = _ds.type else: - _ds = self.session.get(CoreDatasource, _datasource) + _ds = _session.get(CoreDatasource, _datasource) if not _ds: _datasource = None raise SingleMessageError(f"Datasource configuration with id {_datasource} not found") self.ds = CoreDatasource(**_ds.model_dump()) self.chat_question.engine = (_ds.type_name if _ds.type != 'excel' else 'PostgreSQL') + get_version( self.ds) - self.chat_question.db_schema = get_table_schema(session=self.session, + self.chat_question.db_schema = get_table_schema(session=_session, current_user=self.current_user, ds=self.ds, question=self.chat_question.question) _engine_type = self.chat_question.engine _chat.engine_type = _ds.type_name # save chat - with self.session.begin_nested(): + with _session.begin_nested(): # 为了能继续记日志,先单独处理下事务 try: - self.session.add(_chat) - self.session.flush() - self.session.refresh(_chat) - self.session.commit() + _session.add(_chat) + 
_session.flush() + _session.refresh(_chat) + _session.commit() except Exception as e: - self.session.rollback() + _session.rollback() raise e elif data['fail']: @@ -517,7 +510,7 @@ def select_datasource(self): _error = e if not ignore_auto_select and not settings.TABLE_EMBEDDING_ENABLED: - self.record = save_select_datasource_answer(session=self.session, record_id=self.record.id, + self.record = save_select_datasource_answer(session=_session, record_id=self.record.id, answer=orjson.dumps({'content': full_text}).decode(), datasource=_datasource, engine_type=_engine_type) @@ -525,12 +518,12 @@ def select_datasource(self): oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1 ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None - self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, oid, + self.chat_question.terminologies = get_terminology_template(_session, self.chat_question.question, oid, ds_id) - self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, ds_id, + self.chat_question.data_training = get_training_template(_session, self.chat_question.question, ds_id, oid) if SQLBotLicenseUtil.valid(): - self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL, + self.chat_question.custom_prompt = find_custom_prompts(_session, CustomPromptTypeEnum.GENERATE_SQL, oid, ds_id) self.init_messages() @@ -538,12 +531,12 @@ def select_datasource(self): if _error: raise _error - def generate_sql(self): + def generate_sql(self, _session: Session): # append current question self.sql_message.append(HumanMessage( self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')))) - self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=self.session, + self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=_session, ai_modal_id=self.chat_question.ai_modal_id, 
ai_modal_name=self.chat_question.ai_modal_name, operate=OperationEnum.GENERATE_SQL, @@ -564,16 +557,16 @@ def generate_sql(self): self.sql_message.append(AIMessage(full_sql_text)) - self.current_logs[OperationEnum.GENERATE_SQL] = end_log(session=self.session, + self.current_logs[OperationEnum.GENERATE_SQL] = end_log(session=_session, log=self.current_logs[OperationEnum.GENERATE_SQL], full_message=[{'type': msg.type, 'content': msg.content} for msg in self.sql_message], reasoning_content=full_thinking_text, token_usage=token_usage) - self.record = save_sql_answer(session=self.session, record_id=self.record.id, + self.record = save_sql_answer(session=_session, record_id=self.record.id, answer=orjson.dumps({'content': full_sql_text}).decode()) - def generate_with_sub_sql(self, sql, sub_mappings: list): + def generate_with_sub_sql(self, _session: Session, sql, sub_mappings: list): sub_query = json.dumps(sub_mappings, ensure_ascii=False) self.chat_question.sql = sql self.chat_question.sub_query = sub_query @@ -581,7 +574,7 @@ def generate_with_sub_sql(self, sql, sub_mappings: list): dynamic_sql_msg.append(SystemMessage(content=self.chat_question.dynamic_sys_question())) dynamic_sql_msg.append(HumanMessage(content=self.chat_question.dynamic_user_question())) - self.current_logs[OperationEnum.GENERATE_DYNAMIC_SQL] = start_log(session=self.session, + self.current_logs[OperationEnum.GENERATE_DYNAMIC_SQL] = start_log(session=_session, ai_modal_id=self.chat_question.ai_modal_id, ai_modal_name=self.chat_question.ai_modal_name, operate=OperationEnum.GENERATE_DYNAMIC_SQL, @@ -603,7 +596,7 @@ def generate_with_sub_sql(self, sql, sub_mappings: list): dynamic_sql_msg.append(AIMessage(full_dynamic_text)) - self.current_logs[OperationEnum.GENERATE_DYNAMIC_SQL] = end_log(session=self.session, + self.current_logs[OperationEnum.GENERATE_DYNAMIC_SQL] = end_log(session=_session, log=self.current_logs[ OperationEnum.GENERATE_DYNAMIC_SQL], full_message=[ @@ -616,7 +609,7 @@ def 
generate_with_sub_sql(self, sql, sub_mappings: list): SQLBotLogUtil.info(full_dynamic_text) return full_dynamic_text - def generate_assistant_dynamic_sql(self, sql, tables: List): + def generate_assistant_dynamic_sql(self, _session: Session, sql, tables: List): ds: AssistantOutDsSchema = self.ds sub_query = [] result_dict = {} @@ -627,11 +620,11 @@ def generate_assistant_dynamic_sql(self, sql, tables: List): sub_query.append({"table": table.name, "query": f'{dynamic_subsql_prefix}{table.name}'}) if not sub_query: return None - temp_sql_text = self.generate_with_sub_sql(sql=sql, sub_mappings=sub_query) + temp_sql_text = self.generate_with_sub_sql(_session=_session, sql=sql, sub_mappings=sub_query) result_dict['sqlbot_temp_sql_text'] = temp_sql_text return result_dict - def build_table_filter(self, sql: str, filters: list): + def build_table_filter(self, session: Session, sql: str, filters: list): filter = json.dumps(filters, ensure_ascii=False) self.chat_question.sql = sql self.chat_question.filter = filter @@ -639,7 +632,7 @@ def build_table_filter(self, sql: str, filters: list): permission_sql_msg.append(SystemMessage(content=self.chat_question.filter_sys_question())) permission_sql_msg.append(HumanMessage(content=self.chat_question.filter_user_question())) - self.current_logs[OperationEnum.GENERATE_SQL_WITH_PERMISSIONS] = start_log(session=self.session, + self.current_logs[OperationEnum.GENERATE_SQL_WITH_PERMISSIONS] = start_log(session=session, ai_modal_id=self.chat_question.ai_modal_id, ai_modal_name=self.chat_question.ai_modal_name, operate=OperationEnum.GENERATE_SQL_WITH_PERMISSIONS, @@ -661,7 +654,7 @@ def build_table_filter(self, sql: str, filters: list): permission_sql_msg.append(AIMessage(full_filter_text)) - self.current_logs[OperationEnum.GENERATE_SQL_WITH_PERMISSIONS] = end_log(session=session, log=self.current_logs[
OperationEnum.GENERATE_SQL_WITH_PERMISSIONS], full_message=[ @@ -674,14 +667,14 @@ def build_table_filter(self, sql: str, filters: list): SQLBotLogUtil.info(full_filter_text) return full_filter_text - def generate_filter(self, sql: str, tables: List): - filters = get_row_permission_filters(session=self.session, current_user=self.current_user, ds=self.ds, + def generate_filter(self, _session: Session, sql: str, tables: List): + filters = get_row_permission_filters(session=_session, current_user=self.current_user, ds=self.ds, tables=tables) if not filters: return None - return self.build_table_filter(sql=sql, filters=filters) + return self.build_table_filter(session=_session, sql=sql, filters=filters) - def generate_assistant_filter(self, sql, tables: List): + def generate_assistant_filter(self, _session: Session, sql, tables: List): ds: AssistantOutDsSchema = self.ds filters = [] for table in ds.tables: @@ -689,13 +682,13 @@ def generate_assistant_filter(self, sql, tables: List): filters.append({"table": table.name, "filter": table.rule}) if not filters: return None - return self.build_table_filter(sql=sql, filters=filters) + return self.build_table_filter(session=_session, sql=sql, filters=filters) - def generate_chart(self, chart_type: Optional[str] = ''): + def generate_chart(self, _session: Session, chart_type: Optional[str] = ''): # append current question self.chart_message.append(HumanMessage(self.chat_question.chart_user_question(chart_type))) - self.current_logs[OperationEnum.GENERATE_CHART] = start_log(session=self.session, + self.current_logs[OperationEnum.GENERATE_CHART] = start_log(session=_session, ai_modal_id=self.chat_question.ai_modal_id, ai_modal_name=self.chat_question.ai_modal_name, operate=OperationEnum.GENERATE_CHART, @@ -717,9 +710,9 @@ def generate_chart(self, chart_type: Optional[str] = ''): self.chart_message.append(AIMessage(full_chart_text)) - self.record = save_chart_answer(session=self.session, record_id=self.record.id, + self.record = 
save_chart_answer(session=_session, record_id=self.record.id, answer=orjson.dumps({'content': full_chart_text}).decode()) - self.current_logs[OperationEnum.GENERATE_CHART] = end_log(session=self.session, + self.current_logs[OperationEnum.GENERATE_CHART] = end_log(session=_session, log=self.current_logs[OperationEnum.GENERATE_CHART], full_message=[ {'type': msg.type, 'content': msg.content} @@ -773,15 +766,15 @@ def get_chart_type_from_sql_answer(res: str) -> Optional[str]: return chart_type - def check_save_sql(self, res: str) -> str: + def check_save_sql(self, session: Session, res: str) -> str: sql, *_ = self.check_sql(res=res) - save_sql(session=self.session, sql=sql, record_id=self.record.id) + save_sql(session=session, sql=sql, record_id=self.record.id) self.chat_question.sql = sql return sql - def check_save_chart(self, res: str) -> Dict[str, Any]: + def check_save_chart(self, session: Session, res: str) -> Dict[str, Any]: json_str = extract_nested_json(res) if json_str is None: @@ -821,28 +814,28 @@ def check_save_chart(self, res: str) -> Dict[str, Any]: if error: raise SingleMessageError(message) - save_chart(session=self.session, chart=orjson.dumps(chart).decode(), record_id=self.record.id) + save_chart(session=session, chart=orjson.dumps(chart).decode(), record_id=self.record.id) return chart - def check_save_predict_data(self, res: str) -> bool: + def check_save_predict_data(self, session: Session, res: str) -> bool: json_str = extract_nested_json(res) if not json_str: json_str = '' - save_predict_data(session=self.session, record_id=self.record.id, data=json_str) + save_predict_data(session=session, record_id=self.record.id, data=json_str) if json_str == '': return False return True - def save_error(self, message: str): - return save_error_message(session=self.session, record_id=self.record.id, message=message) + def save_error(self, session: Session, message: str): + return save_error_message(session=session, record_id=self.record.id, message=message) 
- def save_sql_data(self, data_obj: Dict[str, Any]): + def save_sql_data(self, session: Session, data_obj: Dict[str, Any]): try: data_result = data_obj.get('data') limit = 1000 @@ -853,13 +846,13 @@ def save_sql_data(self, data_obj: Dict[str, Any]): data_obj['limit'] = limit else: data_obj['data'] = data_result - return save_sql_exec_data(session=self.session, record_id=self.record.id, + return save_sql_exec_data(session=session, record_id=self.record.id, data=orjson.dumps(data_obj).decode()) except Exception as e: raise e - def finish(self): - return finish_record(session=self.session, record_id=self.record.id) + def finish(self, session: Session): + return finish_record(session=session, record_id=self.record.id) def execute_sql(self, sql: str): """Execute SQL query @@ -916,16 +909,18 @@ def run_task_cache(self, in_chat: bool = True, stream: bool = True, def run_task(self, in_chat: bool = True, stream: bool = True, finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART): json_result: Dict[str, Any] = {'success': True} + _session = None try: + _session = session_maker() if self.ds: oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1 ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None - self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, + self.chat_question.terminologies = get_terminology_template(_session, self.chat_question.question, oid, ds_id) - self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, + self.chat_question.data_training = get_training_template(_session, self.chat_question.question, ds_id, oid) if SQLBotLicenseUtil.valid(): - self.chat_question.custom_prompt = find_custom_prompts(self.session, + self.chat_question.custom_prompt = find_custom_prompts(_session, CustomPromptTypeEnum.GENERATE_SQL, oid, ds_id) @@ -940,7 +935,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, # return title if self.change_title: 
if self.chat_question.question or self.chat_question.question.strip() != '': - brief = rename_chat(session=self.session, + brief = rename_chat(session=_session, rename_object=RenameChat(id=self.get_record().chat_id, brief=self.chat_question.question.strip()[:20])) if in_chat: @@ -950,7 +945,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, # select datasource if datasource is none if not self.ds: - ds_res = self.select_datasource() + ds_res = self.select_datasource(_session) for chunk in ds_res: SQLBotLogUtil.info(chunk) @@ -965,12 +960,12 @@ def run_task(self, in_chat: bool = True, stream: bool = True, self.chat_question.db_schema = self.out_ds_instance.get_db_schema( self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema( - session=self.session, + session=_session, current_user=self.current_user, ds=self.ds, question=self.chat_question.question) else: - self.validate_history_ds() + self.validate_history_ds(_session) # check connection connected = check_connection(ds=self.ds, trans=None) @@ -978,7 +973,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, raise SQLBotDBConnectionError('Connect DB failed') # generate sql - sql_res = self.generate_sql() + sql_res = self.generate_sql(_session) full_sql_text = '' for chunk in sql_res: full_sql_text += chunk.get('content') @@ -998,29 +993,29 @@ def run_task(self, in_chat: bool = True, stream: bool = True, dynamic_sql_result = None sqlbot_temp_sql_text = None assistant_dynamic_sql = None - # todo row permission + # row permission if ((not self.current_assistant or is_page_embedded) and is_normal_user( self.current_user)) or use_dynamic_ds: sql, tables = self.check_sql(res=full_sql_text) sql_result = None if use_dynamic_ds: - dynamic_sql_result = self.generate_assistant_dynamic_sql(sql, tables) + dynamic_sql_result = self.generate_assistant_dynamic_sql(_session, sql, tables) sqlbot_temp_sql_text = dynamic_sql_result.get( 'sqlbot_temp_sql_text') if 
dynamic_sql_result else None # sql_result = self.generate_assistant_filter(sql, tables) else: - sql_result = self.generate_filter(sql, tables) # maybe no sql and tables + sql_result = self.generate_filter(_session, sql, tables) # maybe no sql and tables if sql_result: SQLBotLogUtil.info(sql_result) - sql = self.check_save_sql(res=sql_result) + sql = self.check_save_sql(session=_session, res=sql_result) elif dynamic_sql_result and sqlbot_temp_sql_text: - assistant_dynamic_sql = self.check_save_sql(res=sqlbot_temp_sql_text) + assistant_dynamic_sql = self.check_save_sql(session=_session, res=sqlbot_temp_sql_text) else: - sql = self.check_save_sql(res=full_sql_text) + sql = self.check_save_sql(session=_session, res=full_sql_text) else: - sql = self.check_save_sql(res=full_sql_text) + sql = self.check_save_sql(session=_session, res=full_sql_text) SQLBotLogUtil.info('sql: ' + sql) @@ -1051,7 +1046,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, return result = self.execute_sql(sql=real_execute_sql) - self.save_sql_data(data_obj=result) + self.save_sql_data(session=_session, data_obj=result) if in_chat: yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n' if not stream: @@ -1085,7 +1080,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, return # generate chart - chart_res = self.generate_chart(chart_type) + chart_res = self.generate_chart(_session, chart_type) full_chart_text = '' for chunk in chart_res: full_chart_text += chunk.get('content') @@ -1098,7 +1093,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, # filter chart SQLBotLogUtil.info(full_chart_text) - chart = self.check_save_chart(res=full_chart_text) + chart = self.check_save_chart(session=_session, res=full_chart_text) SQLBotLogUtil.info(chart) if not stream: @@ -1170,7 +1165,8 @@ def run_task(self, in_chat: bool = True, stream: bool = True, {'message': 'Execute SQL Failed', 'traceback': str(e), 'type': 
'exec-sql-err'}).decode() else: error_msg = orjson.dumps({'message': str(e), 'traceback': traceback.format_exc(limit=1)}).decode() - self.save_error(message=error_msg) + if _session: + self.save_error(session=_session, message=error_msg) if in_chat: yield 'data:' + orjson.dumps({'content': error_msg, 'type': 'error'}).decode() + '\n\n' else: @@ -1181,7 +1177,8 @@ def run_task(self, in_chat: bool = True, stream: bool = True, json_result['message'] = error_msg yield json_result finally: - self.finish() + self.finish(_session) + session_maker.remove() def run_recommend_questions_task_async(self): self.future = executor.submit(self.run_recommend_questions_task_cache) @@ -1191,19 +1188,26 @@ def run_recommend_questions_task_cache(self): self.chunk_list.append(chunk) def run_recommend_questions_task(self): - res = self.generate_recommend_questions_task() + try: + _session = session_maker() + res = self.generate_recommend_questions_task(_session) - for chunk in res: - if chunk.get('recommended_question'): - yield 'data:' + orjson.dumps( - {'content': chunk.get('recommended_question'), 'type': 'recommended_question'}).decode() + '\n\n' - else: - yield 'data:' + orjson.dumps( - {'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'), - 'type': 'recommended_question_result'}).decode() + '\n\n' + for chunk in res: + if chunk.get('recommended_question'): + yield 'data:' + orjson.dumps( + {'content': chunk.get('recommended_question'), + 'type': 'recommended_question'}).decode() + '\n\n' + else: + yield 'data:' + orjson.dumps( + {'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'), + 'type': 'recommended_question_result'}).decode() + '\n\n' + except Exception: + traceback.print_exc() + finally: + session_maker.remove() - def run_analysis_or_predict_task_async(self, action_type: str, base_record: ChatRecord): - self.set_record(save_analysis_predict_record(self.session, base_record, action_type)) + def 
run_analysis_or_predict_task_async(self, session: Session, action_type: str, base_record: ChatRecord): + self.set_record(save_analysis_predict_record(session, base_record, action_type)) self.future = executor.submit(self.run_analysis_or_predict_task_cache, action_type) def run_analysis_or_predict_task_cache(self, action_type: str): @@ -1211,13 +1215,14 @@ def run_analysis_or_predict_task_cache(self, action_type: str): self.chunk_list.append(chunk) def run_analysis_or_predict_task(self, action_type: str): + _session = None try: - + _session = session_maker() yield 'data:' + orjson.dumps({'type': 'id', 'id': self.get_record().id}).decode() + '\n\n' if action_type == 'analysis': # generate analysis - analysis_res = self.generate_analysis() + analysis_res = self.generate_analysis(_session) for chunk in analysis_res: yield 'data:' + orjson.dumps( {'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'), @@ -1228,7 +1233,7 @@ def run_analysis_or_predict_task(self, action_type: str): elif action_type == 'predict': # generate predict - analysis_res = self.generate_predict() + analysis_res = self.generate_predict(_session) full_text = '' for chunk in analysis_res: yield 'data:' + orjson.dumps( @@ -1237,7 +1242,7 @@ def run_analysis_or_predict_task(self, action_type: str): full_text += chunk.get('content') yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'predict generated'}).decode() + '\n\n' - _data = self.check_save_predict_data(res=full_text) + _data = self.check_save_predict_data(session=_session, res=full_text) if _data: yield 'data:' + orjson.dumps({'type': 'predict-success'}).decode() + '\n\n' else: @@ -1245,31 +1250,32 @@ def run_analysis_or_predict_task(self, action_type: str): yield 'data:' + orjson.dumps({'type': 'predict_finish'}).decode() + '\n\n' - self.finish() + self.finish(_session) except Exception as e: error_msg: str if isinstance(e, SingleMessageError): error_msg = str(e) else: error_msg = orjson.dumps({'message': 
str(e), 'traceback': traceback.format_exc(limit=1)}).decode() - self.save_error(message=error_msg) + if _session: + self.save_error(session=_session, message=error_msg) yield 'data:' + orjson.dumps({'content': error_msg, 'type': 'error'}).decode() + '\n\n' finally: # end - pass + session_maker.remove() - def validate_history_ds(self): + def validate_history_ds(self, session: Session): _ds = self.ds if not self.current_assistant or self.current_assistant.type == 4: try: - current_ds = self.session.get(CoreDatasource, _ds.id) + current_ds = session.get(CoreDatasource, _ds.id) if not current_ds: raise SingleMessageError('chat.ds_is_invalid') except Exception as e: raise SingleMessageError("chat.ds_is_invalid") else: try: - _ds_list: list[dict] = get_assistant_ds(session=self.session, llm_service=self) + _ds_list: list[dict] = get_assistant_ds(session=session, llm_service=self) match_ds = any(item.get("id") == _ds.id for item in _ds_list) if not match_ds: type = self.current_assistant.type diff --git a/backend/apps/mcp/mcp.py b/backend/apps/mcp/mcp.py index 433e4727..76f3feea 100644 --- a/backend/apps/mcp/mcp.py +++ b/backend/apps/mcp/mcp.py @@ -111,8 +111,8 @@ async def mcp_question(session: SessionDep, chat: McpQuestion): mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question) try: - llm_service = await LLMService.create(session_user, mcp_chat) - llm_service.init_record() + llm_service = await LLMService.create(session, session_user, mcp_chat) + llm_service.init_record(session=session) llm_service.run_task_async(False, chat.stream) except Exception as e: traceback.print_exc() @@ -167,8 +167,8 @@ async def mcp_assistant(session: SessionDep, chat: McpAssistant): mcp_chat = ChatQuestion(chat_id=c.id, question=chat.question) # ask try: - llm_service = await LLMService.create(session_user, mcp_chat, mcp_assistant_header) - llm_service.init_record() + llm_service = await LLMService.create(session, session_user, mcp_chat, mcp_assistant_header) + 
llm_service.init_record(session=session) llm_service.run_task_async(False, chat.stream, ChatFinishStep.QUERY_DATA) except Exception as e: traceback.print_exc()