diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py index 1d4815fe..6f34b69b 100644 --- a/backend/apps/chat/api/chat.py +++ b/backend/apps/chat/api/chat.py @@ -1,6 +1,7 @@ import asyncio import io import traceback +from typing import Optional import orjson import pandas as pd @@ -107,7 +108,7 @@ async def start_chat(session: SessionDep, current_user: CurrentUser): @router.post("/recommend_questions/{chat_record_id}") async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int, - current_assistant: CurrentAssistant): + current_assistant: CurrentAssistant, articles_number: Optional[int] = 4): def _return_empty(): yield 'data:' + orjson.dumps({'content': '[]', 'type': 'recommended_question'}).decode() + '\n\n' @@ -121,6 +122,7 @@ def _return_empty(): llm_service = await LLMService.create(session, current_user, request_question, current_assistant, True) llm_service.set_record(record) + llm_service.set_articles_number(articles_number) llm_service.run_recommend_questions_task_async() except Exception as e: traceback.print_exc() diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index 5dbbae3c..966fc7ae 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -212,7 +212,8 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T def sql_user_question(self, current_time: str, change_title: bool): return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question, - rule=self.rule, current_time=current_time, error_msg=self.error_msg,change_title = change_title) + rule=self.rule, current_time=current_time, error_msg=self.error_msg, + change_title=change_title) def chart_sys_question(self): return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang) @@ -240,8 +241,8 @@ def datasource_sys_question(self): def datasource_user_question(self, datasource_list: str = "[]"): return get_datasource_template()['user'].format(question=self.question, data=datasource_list) - def guess_sys_question(self): - return get_guess_question_template()['system'].format(lang=self.lang) + def guess_sys_question(self, articles_number: int = 4): + return get_guess_question_template()['system'].format(lang=self.lang, articles_number=articles_number) def guess_user_question(self, old_questions: str = "[]"): return get_guess_question_template()['user'].format(question=self.question, schema=self.db_schema, diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index c12aaab3..39622eb3 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -84,6 +84,7 @@ class LLMService: future: Future last_execute_sql_error: str = None + articles_number: int = 4 def __init__(self, session: Session, current_user: CurrentUser, chat_question: ChatQuestion, current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False, @@ -213,6 +214,9 @@ def get_record(self): def set_record(self, record: ChatRecord): self.record = record + def set_articles_number(self, articles_number: int): + self.articles_number = articles_number + def get_fields_from_chart(self, _session: Session): chart_info = get_chart_config(_session, self.record.id) return format_chart_fields(chart_info) @@ -330,7 +334,7 @@ def generate_recommend_questions_task(self, _session: Session): embedding=False) guess_msg: List[Union[BaseMessage, dict[str, Any]]] = [] - guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question())) + guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question(self.articles_number))) old_questions = list(map(lambda q: q.strip(), get_old_questions(_session, self.record.datasource))) guess_msg.append( diff --git a/backend/templates/template.yaml b/backend/templates/template.yaml index 8f51253b..83056865 100644 --- a/backend/templates/template.yaml +++ b/backend/templates/template.yaml @@ -374,7 +374,7 @@ template: ### 请使用语言:{lang} 回答,不需要输出深度思考过程 ### 说明: - 您的任务是根据给定的表结构,用户问题以及以往用户提问,推测用户接下来可能提问的1-4个问题。 + 您的任务是根据给定的表结构,用户问题以及以往用户提问,推测用户接下来可能提问的1-{articles_number}个问题。 请遵循以下规则: - 推测的问题需要与提供的表结构相关,生成的提问例子如:["查询所有用户数据","使用饼图展示各产品类型的占比","使用折线图展示销售额趋势",...] - 推测问题如果涉及图形展示,支持的图形类型为:表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie) @@ -385,7 +385,7 @@ template: - 如果用户没有提问且没有以往用户提问,则仅根据提供的表结构推测问题 - 生成的推测问题使用JSON格式返回: ["推测问题1", "推测问题2", "推测问题3", "推测问题4"] - - 最多返回4个你推测出的结果 + - 最多返回{articles_number}个你推测出的结果 - 若无法推测,则返回空数据JSON: [] - 若你的给出的JSON不是{lang}的,则必须翻译为{lang} diff --git a/frontend/src/api/chat.ts b/frontend/src/api/chat.ts index d840b31e..40deaf3a 100644 --- a/frontend/src/api/chat.ts +++ b/frontend/src/api/chat.ts @@ -328,8 +328,12 @@ export const chatApi = { predict: (record_id: number | undefined, controller?: AbortController) => { return request.fetchStream(`/chat/record/${record_id}/predict`, {}, controller) }, - recommendQuestions: (record_id: number | undefined, controller?: AbortController) => { - return request.fetchStream(`/chat/recommend_questions/${record_id}`, {}, controller) + recommendQuestions: ( + record_id: number | undefined, + controller?: AbortController, + params: any + ) => { + return request.fetchStream(`/chat/recommend_questions/${record_id}${params}`, {}, controller) }, recentQuestions: (datasource_id?: number): Promise => { return request.get(`/chat/recent_questions/${datasource_id}`) diff --git a/frontend/src/views/chat/QuickQuestion.vue b/frontend/src/views/chat/QuickQuestion.vue index 80f5b447..fa6fe508 100644 --- a/frontend/src/views/chat/QuickQuestion.vue +++ b/frontend/src/views/chat/QuickQuestion.vue @@ -11,7 +11,7 @@ const recommendQuestionRef = ref() const recentQuestionRef = ref() const popoverRef = ref() const getRecommendQuestions = () => { - recommendQuestionRef.value.getRecommendQuestions() + recommendQuestionRef.value.getRecommendQuestions(10) } const retrieveQuestions = () => { diff --git a/frontend/src/views/chat/RecommendQuestion.vue b/frontend/src/views/chat/RecommendQuestion.vue index c7ff9385..a9562b55 100644 --- a/frontend/src/views/chat/RecommendQuestion.vue +++ b/frontend/src/views/chat/RecommendQuestion.vue @@ -58,12 +58,13 @@ function clickQuestion(question: string): void { const stopFlag = ref(false) -async function getRecommendQuestions() { +async function getRecommendQuestions(articles_number: number) { stopFlag.value = false loading.value = true try { const controller: AbortController = new AbortController() - const response = await chatApi.recommendQuestions(props.recordId, controller) + const params = articles_number ? '?articles_number=' + articles_number : '' + const response = await chatApi.recommendQuestions(props.recordId, controller, params) const reader = response.body.getReader() const decoder = new TextDecoder('utf-8')