diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index dbb9e5b7..d8359f1c 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -413,50 +413,50 @@ def select_datasource(self, _session: Session): if not ignore_auto_select: if settings.TABLE_EMBEDDING_ENABLED and ( not self.current_assistant or (self.current_assistant and self.current_assistant.type != 1)): - ds = get_ds_embedding(_session, self.current_user, _ds_list, self.out_ds_instance, + _ds_list = get_ds_embedding(_session, self.current_user, _ds_list, self.out_ds_instance, self.chat_question.question, self.current_assistant) - yield {'content': '{"id":' + str(ds.get('id')) + '}'} - else: - _ds_list_dict = [] - for _ds in _ds_list: - _ds_list_dict.append(_ds) - datasource_msg.append( - HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode()))) - - self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=_session, - ai_modal_id=self.chat_question.ai_modal_id, - ai_modal_name=self.chat_question.ai_modal_name, - operate=OperationEnum.CHOOSE_DATASOURCE, - record_id=self.record.id, - full_message=[{'type': msg.type, - 'content': msg.content} - for - msg in datasource_msg]) - - token_usage = {} - res = process_stream(self.llm.stream(datasource_msg), token_usage) - for chunk in res: - if chunk.get('content'): - full_text += chunk.get('content') - if chunk.get('reasoning_content'): - full_thinking_text += chunk.get('reasoning_content') - yield chunk - datasource_msg.append(AIMessage(full_text)) - - self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=_session, - log=self.current_logs[ - OperationEnum.CHOOSE_DATASOURCE], - full_message=[ - {'type': msg.type, - 'content': msg.content} - for msg in datasource_msg], - reasoning_content=full_thinking_text, - token_usage=token_usage) - - json_str = extract_nested_json(full_text) - if json_str is None: - raise SingleMessageError(f'Cannot parse datasource from answer: {full_text}') - ds = orjson.loads(json_str) + # yield {'content': '{"id":' + str(ds.get('id')) + '}'} + + _ds_list_dict = [] + for _ds in _ds_list: + _ds_list_dict.append(_ds) + datasource_msg.append( + HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode()))) + + self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=_session, + ai_modal_id=self.chat_question.ai_modal_id, + ai_modal_name=self.chat_question.ai_modal_name, + operate=OperationEnum.CHOOSE_DATASOURCE, + record_id=self.record.id, + full_message=[{'type': msg.type, + 'content': msg.content} + for + msg in datasource_msg]) + + token_usage = {} + res = process_stream(self.llm.stream(datasource_msg), token_usage) + for chunk in res: + if chunk.get('content'): + full_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk + datasource_msg.append(AIMessage(full_text)) + + self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=_session, + log=self.current_logs[ + OperationEnum.CHOOSE_DATASOURCE], + full_message=[ + {'type': msg.type, + 'content': msg.content} + for msg in datasource_msg], + reasoning_content=full_thinking_text, + token_usage=token_usage) + + json_str = extract_nested_json(full_text) + if json_str is None: + raise SingleMessageError(f'Cannot parse datasource from answer: {full_text}') + ds = orjson.loads(json_str) _error: Exception | None = None _datasource: int | None = None diff --git a/backend/apps/datasource/embedding/ds_embedding.py b/backend/apps/datasource/embedding/ds_embedding.py index 49fa4020..657a3bf0 100644 --- a/backend/apps/datasource/embedding/ds_embedding.py +++ b/backend/apps/datasource/embedding/ds_embedding.py @@ -9,6 +9,7 @@ from apps.datasource.embedding.utils import cosine_similarity from apps.datasource.models.datasource import CoreDatasource from apps.system.crud.assistant import AssistantOutDs +from common.core.config import settings from common.core.deps import CurrentAssistant from common.core.deps import SessionDep, CurrentUser from common.utils.utils import SQLBotLogUtil @@ -45,8 +46,9 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o [{"id": ele.get("id"), "name": ele.get("ds").name, "cosine_similarity": ele.get("cosine_similarity")} for ele in _list])) - ds = _list[0].get('ds') - return {"id": ds.id, "name": ds.name, "description": ds.description} + ds_l = _list[:settings.DS_EMBEDDING_COUNT] + return [{"id": obj.get('ds').id, "name": obj.get('ds').name, "description": obj.get('ds').description} + for obj in ds_l] except Exception: traceback.print_exc() else: @@ -81,8 +83,9 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o [{"id": ele.get("id"), "name": ele.get("ds").name, "cosine_similarity": ele.get("cosine_similarity")} for ele in _list])) - ds = _list[0].get('ds') - return {"id": ds.id, "name": ds.name, "description": ds.description} + ds_l = _list[:settings.DS_EMBEDDING_COUNT] + return [{"id": obj.get('ds').id, "name": obj.get('ds').name, "description": obj.get('ds').description} + for obj in ds_l] except Exception: traceback.print_exc() return _list diff --git a/backend/common/core/config.py b/backend/common/core/config.py index 0850df2d..bf9d0759 100644 --- a/backend/common/core/config.py +++ b/backend/common/core/config.py @@ -109,6 +109,7 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str: TABLE_EMBEDDING_ENABLED: bool = True TABLE_EMBEDDING_COUNT: int = 10 + DS_EMBEDDING_COUNT: int = 10 settings = Settings() # type: ignore