diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py index dc446685..1d4815fe 100644 --- a/backend/apps/chat/api/chat.py +++ b/backend/apps/chat/api/chat.py @@ -10,7 +10,7 @@ from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \ delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \ - format_json_data, format_json_list_data, get_chart_config + format_json_data, format_json_list_data, get_chart_config, list_recent_questions from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj from apps.chat.task.llm import LLMService from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans @@ -132,6 +132,10 @@ def _err(_e: Exception): return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") +@router.get("/recent_questions/{datasource_id}") +async def recommend_questions(session: SessionDep, current_user: CurrentUser, datasource_id: int): + return list_recent_questions(session=session, current_user=current_user, datasource_id=datasource_id) + @router.post("/question") async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion, diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index 48b3f054..0872c6f4 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -1,5 +1,6 @@ import datetime from typing import List +from sqlalchemy import desc, func import orjson import sqlparse @@ -8,7 +9,8 @@ from apps.chat.models.chat_model import Chat, ChatRecord, CreateChat, ChatInfo, RenameChat, ChatQuestion, ChatLog, \ TypeEnum, OperationEnum, ChatRecordResult -from apps.datasource.models.datasource import CoreDatasource +from apps.datasource.crud.recommended_problem import get_datasource_recommended, get_datasource_recommended_chart +from apps.datasource.models.datasource import CoreDatasource, DsRecommendedProblem from apps.system.crud.assistant import AssistantOutDsFactory from common.core.deps import CurrentAssistant, SessionDep, CurrentUser from common.utils.utils import extract_nested_json @@ -34,6 +36,21 @@ def list_chats(session: SessionDep, current_user: CurrentUser) -> List[Chat]: return chart_list +def list_recent_questions(session: SessionDep, current_user: CurrentUser, datasource_id: int) -> List[str]: + chat_records = ( + session.query(ChatRecord.question) + .filter( + ChatRecord.datasource == datasource_id, + ChatRecord.question.isnot(None) + ) + .group_by(ChatRecord.question) + .order_by(desc(func.max(ChatRecord.create_time))) + .limit(10) + .all() + ) + return [record[0] for record in chat_records] if chat_records else [] + + def rename_chat(session: SessionDep, rename_object: RenameChat) -> str: chat = session.get(Chat, rename_object.id) if not chat: @@ -70,6 +87,7 @@ def get_chart_config(session: SessionDep, chart_record_id: int): pass return {} + def format_chart_fields(chart_info: dict): fields = [] if chart_info.get('columns') and len(chart_info.get('columns')) > 0: @@ -88,6 +106,7 @@ def format_chart_fields(chart_info: dict): fields.append(column_str) return fields + def get_last_execute_sql_error(session: SessionDep, chart_id: int): stmt = select(ChatRecord.error).where(and_(ChatRecord.chat_id == chart_id)).order_by( ChatRecord.create_time.desc()).limit(1) @@ -396,6 +415,12 @@ def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: record.finish = True record.create_time = datetime.datetime.now() record.create_by = current_user.id + if ds.recommended_config == 2: + questions = get_datasource_recommended_chart(session, ds.id) + record.recommended_question = orjson.dumps(questions).decode() + record.recommended_question_answer = orjson.dumps({ + "content": questions + }).decode() _record = ChatRecord(**record.model_dump()) diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index 5e7f0136..5dbbae3c 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -185,8 +185,12 @@ class AiModelQuestion(BaseModel): def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True): _sql_template = get_sql_example_template(db_type) - _query_limit = get_sql_template()['query_limit'] if enable_query_limit else get_sql_template()['no_query_limit'] - _base_sql_rules = _sql_template['quot_rule'] + _query_limit + _sql_template['limit_rule'] + _sql_template['other_rule'] + _base_template = get_sql_template() + _process_check = _sql_template.get('process_check') if _sql_template.get('process_check') else _base_template[ + 'process_check'] + _query_limit = _base_template['query_limit'] if enable_query_limit else _base_template['no_query_limit'] + _base_sql_rules = _sql_template['quot_rule'] + _query_limit + _sql_template['limit_rule'] + _sql_template[ + 'other_rule'] _sql_examples = _sql_template['basic_example'] _example_engine = _sql_template['example_engine'] _example_answer_1 = _sql_template['example_answer_1_with_limit'] if enable_query_limit else _sql_template[ @@ -195,19 +199,20 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T 'example_answer_2'] _example_answer_3 = _sql_template['example_answer_3_with_limit'] if enable_query_limit else _sql_template[ 'example_answer_3'] - return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question, - lang=self.lang, terminologies=self.terminologies, - data_training=self.data_training, custom_prompt=self.custom_prompt, - base_sql_rules=_base_sql_rules, - basic_sql_examples=_sql_examples, - example_engine=_example_engine, - example_answer_1=_example_answer_1, - example_answer_2=_example_answer_2, - example_answer_3=_example_answer_3) - - def sql_user_question(self, current_time: str): + return _base_template['system'].format(engine=self.engine, schema=self.db_schema, question=self.question, + lang=self.lang, terminologies=self.terminologies, + data_training=self.data_training, custom_prompt=self.custom_prompt, + process_check=_process_check, + base_sql_rules=_base_sql_rules, + basic_sql_examples=_sql_examples, + example_engine=_example_engine, + example_answer_1=_example_answer_1, + example_answer_2=_example_answer_2, + example_answer_3=_example_answer_3) + + def sql_user_question(self, current_time: str, change_title: bool): return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question, - rule=self.rule, current_time=current_time, error_msg=self.error_msg) + rule=self.rule, current_time=current_time, error_msg=self.error_msg,change_title = change_title) def chart_sys_question(self): return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang) diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 801cbf16..c12aaab3 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -524,7 +524,7 @@ def select_datasource(self, _session: Session): def generate_sql(self, _session: Session): # append current question self.sql_message.append(HumanMessage( - self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')))) + self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),change_title = self.change_title))) self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=_session, ai_modal_id=self.chat_question.ai_modal_id, @@ -756,6 +756,26 @@ def get_chart_type_from_sql_answer(res: str) -> Optional[str]: return chart_type + @staticmethod + def get_brief_from_sql_answer(res: str) -> Optional[str]: + json_str = extract_nested_json(res) + if json_str is None: + return None + + brief: Optional[str] + data: dict + try: + data = orjson.loads(json_str) + + if data['success']: + brief = data['brief'] + else: + return None + except Exception: + return None + + return brief + def check_save_sql(self, session: Session, res: str) -> str: sql, *_ = self.check_sql(res=res) save_sql(session=session, sql=sql, record_id=self.record.id) @@ -925,17 +945,6 @@ def run_task(self, in_chat: bool = True, stream: bool = True, if not stream: json_result['record_id'] = self.get_record().id - # return title - if self.change_title: - if self.chat_question.question and self.chat_question.question.strip() != '': - brief = rename_chat(session=_session, - rename_object=RenameChat(id=self.get_record().chat_id, - brief=self.chat_question.question.strip()[:20])) - if in_chat: - yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n' - if not stream: - json_result['title'] = brief - # select datasource if datasource is none if not self.ds: ds_res = self.select_datasource(_session) @@ -981,6 +990,19 @@ def run_task(self, in_chat: bool = True, stream: bool = True, chart_type = self.get_chart_type_from_sql_answer(full_sql_text) + # return title + if self.change_title: + llm_brief = self.get_brief_from_sql_answer(full_sql_text) + if (llm_brief and llm_brief != '') or (self.chat_question.question and self.chat_question.question.strip() != ''): + save_brief = llm_brief if (llm_brief and llm_brief != '') else self.chat_question.question.strip()[:20] + brief = rename_chat(session=_session, + rename_object=RenameChat(id=self.get_record().chat_id, + brief=save_brief)) + if in_chat: + yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n' + if not stream: + json_result['title'] = brief + use_dynamic_ds: bool = self.current_assistant and self.current_assistant.type in dynamic_ds_types is_page_embedded: bool = self.current_assistant and self.current_assistant.type == 4 dynamic_sql_result = None @@ -1047,7 +1069,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, if in_chat: yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n' if not stream: - json_result['data'] = result.get('data') + json_result['data'] = get_chat_chart_data(_session, self.record.id) if finish_step.value <= ChatFinishStep.QUERY_DATA.value: if stream: diff --git a/backend/apps/data_training/api/data_training.py b/backend/apps/data_training/api/data_training.py index e019465a..ed4d7387 100644 --- a/backend/apps/data_training/api/data_training.py +++ b/backend/apps/data_training/api/data_training.py @@ -17,6 +17,7 @@ from common.core.config import settings from common.core.deps import SessionDep, CurrentUser, Trans from common.utils.data_format import DataFormat +from common.utils.excel import get_excel_column_count router = APIRouter(tags=["DataTraining"], prefix="/system/data-training") @@ -73,8 +74,8 @@ def inner(): data_list.append(_data) fields = [] - fields.append(AxisObj(name=trans('i18n_data_training.data_training'), value='question')) - fields.append(AxisObj(name=trans('i18n_data_training.problem_description'), value='description')) + fields.append(AxisObj(name=trans('i18n_data_training.problem_description'), value='question')) + fields.append(AxisObj(name=trans('i18n_data_training.sample_sql'), value='description')) fields.append(AxisObj(name=trans('i18n_data_training.effective_data_sources'), value='datasource_name')) if current_user.oid == 1: fields.append( @@ -97,6 +98,43 @@ def inner(): return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet") +@router.get("/template") +async def excel_template(trans: Trans, current_user: CurrentUser): + def inner(): + data_list = [] + _data1 = { + "question": '查询TEST表内所有ID', + "description": 'SELECT id FROM TEST', + "datasource_name": '生效数据源1', + "advanced_application_name": '生效高级应用名称', + } + data_list.append(_data1) + + fields = [] + fields.append(AxisObj(name=trans('i18n_data_training.problem_description_template'), value='question')) + fields.append(AxisObj(name=trans('i18n_data_training.sample_sql_template'), value='description')) + fields.append(AxisObj(name=trans('i18n_data_training.effective_data_sources_template'), value='datasource_name')) + if current_user.oid == 1: + fields.append( + AxisObj(name=trans('i18n_data_training.advanced_application_template'), value='advanced_application_name')) + + md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list) + + df = pd.DataFrame(md_data, columns=_fields_list) + + buffer = io.BytesIO() + + with pd.ExcelWriter(buffer, engine='xlsxwriter', + engine_kwargs={'options': {'strings_to_numbers': False}}) as writer: + df.to_excel(writer, sheet_name='Sheet1', index=False) + + buffer.seek(0) + return io.BytesIO(buffer.getvalue()) + + result = await asyncio.to_thread(inner) + return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet") + + path = settings.EXCEL_PATH from sqlalchemy.orm import sessionmaker, scoped_session @@ -136,6 +174,9 @@ def inner(): for sheet_name in sheet_names: + if get_excel_column_count(save_path, sheet_name) < len(use_cols): + raise Exception(trans("i18n_excel_import.col_num_not_match")) + df = pd.read_excel( save_path, sheet_name=sheet_name, diff --git a/backend/apps/datasource/api/recommended_problem.py b/backend/apps/datasource/api/recommended_problem.py index 29fac41c..b7c75624 100644 --- a/backend/apps/datasource/api/recommended_problem.py +++ b/backend/apps/datasource/api/recommended_problem.py @@ -1,7 +1,10 @@ from fastapi import APIRouter -from apps.datasource.crud.recommended_problem import get_datasource_recommended -from common.core.deps import SessionDep +from apps.datasource.crud.datasource import update_ds_recommended_config +from apps.datasource.crud.recommended_problem import get_datasource_recommended, \ + save_recommended_problem +from apps.datasource.models.datasource import RecommendedProblemBase +from common.core.deps import SessionDep, CurrentUser router = APIRouter(tags=["recommended_problem"], prefix="/recommended_problem") @@ -10,3 +13,8 @@ async def datasource_recommended(session: SessionDep, ds_id: int): return get_datasource_recommended(session, ds_id) + +@router.post("/save_recommended_problem") +async def datasource_recommended(session: SessionDep, user: CurrentUser, data_info: RecommendedProblemBase): + update_ds_recommended_config(session, data_info.datasource_id, data_info.recommended_config) + return save_recommended_problem(session, user, data_info) diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index a6a867e1..e3a9a833 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -109,6 +109,11 @@ def update_ds(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreData run_save_ds_embeddings([ds.id]) return ds +def update_ds_recommended_config(session: SessionDep,datasource_id: int, recommended_config:int): + record = session.exec(select(CoreDatasource).where(CoreDatasource.id == datasource_id)).first() + record.recommended_config = recommended_config + session.add(record) + session.commit() def delete_ds(session: SessionDep, id: int): term = session.exec(select(CoreDatasource).where(CoreDatasource.id == id)).first() diff --git a/backend/apps/datasource/crud/recommended_problem.py b/backend/apps/datasource/crud/recommended_problem.py index e3f8ab67..21bbff71 100644 --- a/backend/apps/datasource/crud/recommended_problem.py +++ b/backend/apps/datasource/crud/recommended_problem.py @@ -1,11 +1,31 @@ +import datetime + from sqlmodel import select -from common.core.deps import SessionDep -from ..models.datasource import DsRecommendedProblem +from common.core.deps import SessionDep, CurrentUser, Trans +from ..models.datasource import DsRecommendedProblem, RecommendedProblemBase, RecommendedProblemBaseChat def get_datasource_recommended(session: SessionDep, ds_id: int): statement = select(DsRecommendedProblem).where(DsRecommendedProblem.datasource_id == ds_id) - dsRecommendedProblem = session.exec(statement) + dsRecommendedProblem = session.exec(statement).all() return dsRecommendedProblem +def get_datasource_recommended_chart(session: SessionDep, ds_id: int): + statement = select(DsRecommendedProblem.question).where(DsRecommendedProblem.datasource_id == ds_id) + dsRecommendedProblems = session.exec(statement).all() + return dsRecommendedProblems + +def save_recommended_problem(session: SessionDep,user: CurrentUser, data_info: RecommendedProblemBase): + session.query(DsRecommendedProblem).filter(DsRecommendedProblem.datasource_id == data_info.datasource_id).delete(synchronize_session=False) + problemInfo = data_info.problemInfo + if problemInfo is not None: + for problemItem in problemInfo: + problemItem.id = None + problemItem.create_time = datetime.datetime.now() + problemItem.create_by = user.id + record = DsRecommendedProblem(**problemItem.model_dump()) + session.add(record) + session.flush() + session.refresh(record) + session.commit() diff --git a/backend/apps/datasource/models/datasource.py b/backend/apps/datasource/models/datasource.py index 2cdfed5a..dd814ea4 100644 --- a/backend/apps/datasource/models/datasource.py +++ b/backend/apps/datasource/models/datasource.py @@ -74,6 +74,18 @@ class CreateDatasource(BaseModel): tables: List[CoreTable] = [] recommended_config: int = 1 +class RecommendedProblemBase(BaseModel): + datasource_id: int = None + recommended_config: int = None + problemInfo: List[DsRecommendedProblem] = [] + + +class RecommendedProblemBaseChat: + def __init__(self, content): + self.content = content + + content: List[str] = [] + # edit local saved table and fields class TableObj(BaseModel): diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py index 74d2e6b1..51084236 100644 --- a/backend/apps/db/db.py +++ b/backend/apps/db/db.py @@ -24,7 +24,7 @@ from apps.datasource.utils.utils import aes_decrypt from apps.db.constant import DB, ConnectType from apps.db.engine import get_engine_config -from apps.system.crud.assistant import get_ds_engine +from apps.system.crud.assistant import get_out_ds_conf from apps.system.schemas.system_schema import AssistantOutDsSchema from common.core.deps import Trans from common.utils.utils import SQLBotLogUtil, equals_ignore_case @@ -146,92 +146,25 @@ def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine: def get_session(ds: CoreDatasource | AssistantOutDsSchema): - engine = get_engine(ds) if isinstance(ds, CoreDatasource) else get_ds_engine(ds) + # engine = get_engine(ds) if isinstance(ds, CoreDatasource) else get_ds_engine(ds) + if isinstance(ds, AssistantOutDsSchema): + out_conf = get_out_ds_conf(ds, 30) + ds.configuration = out_conf + + engine = get_engine(ds) session_maker = sessionmaker(bind=engine) session = session_maker() return session def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDsSchema, is_raise: bool = False): - if isinstance(ds, CoreDatasource): - db = DB.get_db(ds.type) - if db.connect_type == ConnectType.sqlalchemy: - conn = get_engine(ds, 10) - try: - with conn.connect() as connection: - SQLBotLogUtil.info("success") - return True - except Exception as e: - SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") - if is_raise: - raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') - return False - else: - conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) - extra_config_dict = get_extra_config(conf) - if equals_ignore_case(ds.type, 'dm'): - with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, - port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: - try: - cursor.execute('select 1', timeout=10).fetchall() - SQLBotLogUtil.info("success") - return True - except Exception as e: - SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") - if is_raise: - raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') - return False - elif equals_ignore_case(ds.type, 'doris', 'starrocks'): - with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, - port=conf.port, db=conf.database, connect_timeout=10, - read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: - try: - cursor.execute('select 1') - SQLBotLogUtil.info("success") - return True - except Exception as e: - SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") - if is_raise: - raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') - return False - elif equals_ignore_case(ds.type, 'redshift'): - with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, - user=conf.username, - password=conf.password, - timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: - try: - cursor.execute('select 1') - SQLBotLogUtil.info("success") - return True - except Exception as e: - SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") - if is_raise: - raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') - return False - elif equals_ignore_case(ds.type, 'kingbase'): - with psycopg2.connect(host=conf.host, port=conf.port, database=conf.database, - user=conf.username, - password=conf.password, - connect_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: - try: - cursor.execute('select 1') - SQLBotLogUtil.info("success") - return True - except Exception as e: - SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") - if is_raise: - raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') - return False - elif equals_ignore_case(ds.type, 'es'): - es_conn = get_es_connect(conf) - if es_conn.ping(): - SQLBotLogUtil.info("success") - return True - else: - SQLBotLogUtil.info("failed") - return False - else: - conn = get_ds_engine(ds) + if isinstance(ds, AssistantOutDsSchema): + out_conf = get_out_ds_conf(ds, 10) + ds.configuration = out_conf + + db = DB.get_db(ds.type) + if db.connect_type == ConnectType.sqlalchemy: + conn = get_engine(ds, 10) try: with conn.connect() as connection: SQLBotLogUtil.info("success") @@ -241,26 +174,102 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs if is_raise: raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False + else: + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) + extra_config_dict = get_extra_config(conf) + if equals_ignore_case(ds.type, 'dm'): + with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, + port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: + try: + cursor.execute('select 1', timeout=10).fetchall() + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + elif equals_ignore_case(ds.type, 'doris', 'starrocks'): + with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, + port=conf.port, db=conf.database, connect_timeout=10, + read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: + try: + cursor.execute('select 1') + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + elif equals_ignore_case(ds.type, 'redshift'): + with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, + user=conf.username, + password=conf.password, + timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: + try: + cursor.execute('select 1') + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + elif equals_ignore_case(ds.type, 'kingbase'): + with psycopg2.connect(host=conf.host, port=conf.port, database=conf.database, + user=conf.username, + password=conf.password, + connect_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: + try: + cursor.execute('select 1') + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + elif equals_ignore_case(ds.type, 'es'): + es_conn = get_es_connect(conf) + if es_conn.ping(): + SQLBotLogUtil.info("success") + return True + else: + SQLBotLogUtil.info("failed") + return False + # else: + # conn = get_ds_engine(ds) + # try: + # with conn.connect() as connection: + # SQLBotLogUtil.info("success") + # return True + # except Exception as e: + # SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + # if is_raise: + # raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + # return False return False def get_version(ds: CoreDatasource | AssistantOutDsSchema): version = '' - conf = None if isinstance(ds, CoreDatasource): conf = DatasourceConf( **json.loads(aes_decrypt(ds.configuration))) if not equals_ignore_case(ds.type, "excel") else get_engine_config() - if isinstance(ds, AssistantOutDsSchema): - conf = DatasourceConf() - conf.host = ds.host - conf.port = ds.port - conf.username = ds.user - conf.password = ds.password - conf.database = ds.dataBase - conf.dbSchema = ds.db_schema - conf.timeout = 10 + else: + conf = DatasourceConf(**json.loads(aes_decrypt(get_out_ds_conf(ds, 10)))) + # if isinstance(ds, AssistantOutDsSchema): + # conf = DatasourceConf() + # conf.host = ds.host + # conf.port = ds.port + # conf.username = ds.user + # conf.password = ds.password + # conf.database = ds.dataBase + # conf.dbSchema = ds.db_schema + # conf.timeout = 10 db = DB.get_db(ds.type) sql = get_version_sql(ds, conf) try: diff --git a/backend/apps/system/api/assistant.py b/backend/apps/system/api/assistant.py index 9733d3f0..c6036142 100644 --- a/backend/apps/system/api/assistant.py +++ b/backend/apps/system/api/assistant.py @@ -17,7 +17,7 @@ from common.core.deps import SessionDep, Trans from common.core.security import create_access_token from common.core.sqlbot_cache import clear_cache -from common.utils.utils import get_origin_from_referer +from common.utils.utils import get_origin_from_referer, origin_match_domain router = APIRouter(tags=["system/assistant"], prefix="/system/assistant") @@ -30,13 +30,15 @@ async def info(request: Request, response: Response, session: SessionDep, trans: if not db_model: raise RuntimeError(f"assistant application not exist") db_model = AssistantModel.model_validate(db_model) - response.headers["Access-Control-Allow-Origin"] = db_model.domain + origin = request.headers.get("origin") or get_origin_from_referer(request) if not origin: raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or '')) origin = origin.rstrip('/') - if origin != db_model.domain: + if not origin_match_domain(origin, db_model.domain): raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or '')) + + response.headers["Access-Control-Allow-Origin"] = origin return db_model @@ -48,13 +50,14 @@ async def getApp(request: Request, response: Response, session: SessionDep, tran if not db_model: raise RuntimeError(f"assistant application not exist") db_model = AssistantModel.model_validate(db_model) - response.headers["Access-Control-Allow-Origin"] = db_model.domain origin = request.headers.get("origin") or get_origin_from_referer(request) if not origin: raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or '')) origin = origin.rstrip('/') - if origin != db_model.domain: + if not origin_match_domain(origin, db_model.domain): raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or '')) + + response.headers["Access-Control-Allow-Origin"] = origin return db_model diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py index 72168221..5d731847 100644 --- a/backend/apps/system/crud/assistant.py +++ b/backend/apps/system/crud/assistant.py @@ -10,6 +10,7 @@ # from apps.datasource.embedding.table_embedding import get_table_embedding from apps.datasource.models.datasource import CoreDatasource, DatasourceConf +from apps.datasource.utils.utils import aes_encrypt from apps.system.models.system_model import AssistantModel from apps.system.schemas.auth import CacheName, CacheNamespace from apps.system.schemas.system_schema import AssistantHeader, AssistantOutDsSchema, UserInfoDTO @@ -99,17 +100,22 @@ class AssistantOutDs: assistant: AssistantHeader ds_list: Optional[list[AssistantOutDsSchema]] = None certificate: Optional[str] = None + request_origin: Optional[str] = None def __init__(self, assistant: AssistantHeader): self.assistant = assistant self.ds_list = None self.certificate = assistant.certificate + self.request_origin = assistant.request_origin self.get_ds_from_api() # @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id") def get_ds_from_api(self): config: dict[any] = json.loads(self.assistant.configuration) endpoint: str = config['endpoint'] + endpoint = self.get_complete_endpoint(endpoint=endpoint) + if not endpoint: + raise Exception(f"Failed to get datasource list from {config['endpoint']}, error: [Assistant domain or endpoint miss]") certificateList: list[any] = json.loads(self.certificate) header = {} cookies = {} @@ -137,6 +143,17 @@ def get_ds_from_api(self): else: raise Exception(f"Failed to get datasource list from {endpoint}, status code: {res.status_code}") + def get_complete_endpoint(self, endpoint: str) -> str | None: + if endpoint.startswith("http://") or endpoint.startswith("https://"): + return endpoint + domain_text = self.assistant.domain + if not domain_text: + return None + if ',' in domain_text: + return (self.request_origin.strip('/') if self.request_origin else domain_text.split(',')[0].strip('/')) + endpoint + else: + return f"{domain_text}{endpoint}" + def get_simple_ds_list(self): if self.ds_list: return [{'id': ds.id, 'name': ds.name, 'description': ds.comment} for ds in self.ds_list] @@ -153,7 +170,7 @@ def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> st for table in ds.tables: i += 1 schema_table = '' - schema_table += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" else f"# Table: {table.name}" + schema_table += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {table.name}" table_comment = table.comment if table_comment == '': schema_table += '\n[\n' @@ -250,3 +267,19 @@ def get_ds_engine(ds: AssistantOutDsSchema) -> Engine: else: engine = create_engine(uri, connect_args={"connect_timeout": timeout}, pool_timeout=timeout) return engine + + +def get_out_ds_conf(ds: AssistantOutDsSchema, timeout:int=30) -> str: + conf = { + "host":ds.host or '', + "port":ds.port or 0, + "username":ds.user or '', + "password":ds.password or '', + "database":ds.dataBase or '', + "driver":'', + "extraJdbc":ds.extraParams or '', + "dbSchema":ds.db_schema or '', + "timeout":timeout or 30 + } + conf["extraJdbc"] = '' + return aes_encrypt(json.dumps(conf)) diff --git a/backend/apps/system/middleware/auth.py b/backend/apps/system/middleware/auth.py index 3ea720c6..ecc7b416 100644 --- a/backend/apps/system/middleware/auth.py +++ b/backend/apps/system/middleware/auth.py @@ -16,7 +16,7 @@ from common.core.config import settings from common.core.schemas import TokenPayload from common.utils.locale import I18n -from common.utils.utils import SQLBotLogUtil +from common.utils.utils import SQLBotLogUtil, get_origin_from_referer from common.utils.whitelist import whiteUtils from fastapi.security.utils import get_authorization_scheme_param from common.core.deps import get_i18n @@ -40,6 +40,9 @@ async def dispatch(self, request, call_next): if validator[0]: request.state.current_user = validator[1] request.state.assistant = validator[2] + origin = request.headers.get("X-SQLBOT-HOST-ORIGIN") or get_origin_from_referer(request) + if origin and validator[2]: + request.state.assistant.request_origin = origin return await call_next(request) message = trans('i18n_permission.authenticate_invalid', msg = validator[1]) return JSONResponse(message, status_code=401, headers={"Access-Control-Allow-Origin": "*"}) diff --git a/backend/apps/system/schemas/system_schema.py b/backend/apps/system/schemas/system_schema.py index 52dc20c9..f16f6dd3 100644 --- a/backend/apps/system/schemas/system_schema.py +++ b/backend/apps/system/schemas/system_schema.py @@ -116,6 +116,7 @@ class AssistantHeader(AssistantDTO): unique: Optional[str] = None certificate: Optional[str] = None online: bool = False + request_origin: Optional[str] = None class AssistantValidator(BaseModel): @@ -177,6 +178,7 @@ class AssistantOutDsBase(BaseModel): type_name: Optional[str] = None comment: Optional[str] = None description: Optional[str] = None + configuration: Optional[str] = None class AssistantOutDsSchema(AssistantOutDsBase): diff --git a/backend/apps/terminology/api/terminology.py b/backend/apps/terminology/api/terminology.py index 34db52c5..16cb6cff 100644 --- a/backend/apps/terminology/api/terminology.py +++ b/backend/apps/terminology/api/terminology.py @@ -17,15 +17,17 @@ from common.core.config import settings from common.core.deps import SessionDep, CurrentUser, Trans from common.utils.data_format import DataFormat +from common.utils.excel import get_excel_column_count router = APIRouter(tags=["Terminology"], prefix="/system/terminology") @router.get("/page/{current_page}/{page_size}") async def pager(session: SessionDep, current_user: CurrentUser, current_page: int, page_size: int, - word: Optional[str] = Query(None, description="搜索术语(可选)")): + word: Optional[str] = Query(None, description="搜索术语(可选)"), + dslist: Optional[list[int]] = Query(None, description="数据集ID集合(可选)")): current_page, page_size, total_count, total_pages, _list = page_terminology(session, current_page, page_size, word, - current_user.oid) + current_user.oid, dslist) return { "current_page": current_page, @@ -96,6 +98,51 @@ def inner(): return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet") +@router.get("/template") +async def excel_template(trans: Trans): + def inner(): + data_list = [] + _data1 = { + "word": trans('i18n_terminology.term_name_template_example_1'), + "other_words": trans('i18n_terminology.synonyms_template_example_1'), + "description": trans('i18n_terminology.term_description_template_example_1'), + "all_data_sources": 'N', + "datasource": trans('i18n_terminology.effective_data_sources_template_example_1'), + } + data_list.append(_data1) + _data2 = { + "word": trans('i18n_terminology.term_name_template_example_2'), + "other_words": trans('i18n_terminology.synonyms_template_example_2'), + "description": trans('i18n_terminology.term_description_template_example_2'), + "all_data_sources": 'Y', + "datasource": '', + } + data_list.append(_data2) + + fields = [] + fields.append(AxisObj(name=trans('i18n_terminology.term_name_template'), value='word')) + fields.append(AxisObj(name=trans('i18n_terminology.synonyms_template'), value='other_words')) + fields.append(AxisObj(name=trans('i18n_terminology.term_description_template'), value='description')) + fields.append(AxisObj(name=trans('i18n_terminology.effective_data_sources_template'), value='datasource')) + fields.append(AxisObj(name=trans('i18n_terminology.all_data_sources_template'), value='all_data_sources')) + + md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list) + + df = pd.DataFrame(md_data, columns=_fields_list) + + buffer = io.BytesIO() + + with pd.ExcelWriter(buffer, engine='xlsxwriter', + engine_kwargs={'options': {'strings_to_numbers': False}}) as writer: + df.to_excel(writer, sheet_name='Sheet1', index=False) + + buffer.seek(0) + return io.BytesIO(buffer.getvalue()) + + result = await asyncio.to_thread(inner) + return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet") + + path = settings.EXCEL_PATH from sqlalchemy.orm import sessionmaker, scoped_session @@ -132,6 +179,9 @@ def inner(): for sheet_name in sheet_names: + if get_excel_column_count(save_path, sheet_name) < len(use_cols): + raise Exception(trans("i18n_excel_import.col_num_not_match")) + df = pd.read_excel( save_path, sheet_name=sheet_name, diff --git a/backend/apps/terminology/curd/terminology.py b/backend/apps/terminology/curd/terminology.py index 282041fc..4a866cbe 100644 --- a/backend/apps/terminology/curd/terminology.py +++ b/backend/apps/terminology/curd/terminology.py @@ -55,12 +55,32 @@ def get_terminology_base_query(oid: int, name: Optional[str] = None): def build_terminology_query(session: SessionDep, oid: int, name: Optional[str] = None, - paginate: bool = True, current_page: int = 1, page_size: int = 10): + paginate: bool = True, current_page: int = 1, page_size: int = 10, + dslist: Optional[list[int]] = None): """ 构建术语查询的通用方法 """ parent_ids_subquery, child = get_terminology_base_query(oid, name) + # 添加数据源筛选条件 + if dslist is not None and len(dslist) > 0: + datasource_conditions = [] + # datasource_ids 与 dslist 中的任一元素有交集 + for ds_id in dslist: + # 使用 JSONB 包含操作符,但需要确保类型正确 + datasource_conditions.append( + Terminology.datasource_ids.contains([ds_id]) + ) + + # datasource_ids 为空数组 + empty_array_condition = Terminology.datasource_ids == [] + + ds_filter_condition = or_( + *datasource_conditions, + empty_array_condition + ) + parent_ids_subquery = parent_ids_subquery.where(ds_filter_condition) + # 计算总数 count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery()) total_count = session.execute(count_stmt).scalar() @@ -176,12 +196,12 @@ def execute_terminology_query(session: SessionDep, stmt) -> List[TerminologyInfo def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, - name: Optional[str] = None, oid: Optional[int] = 1): + name: Optional[str] = None, oid: Optional[int] = 1, dslist: Optional[list[int]] = None): """ 分页查询术语(原方法保持不变) """ stmt, total_count, total_pages, current_page, page_size = build_terminology_query( - session, oid, name, True, current_page, page_size + session, oid, name, True, current_page, page_size, dslist ) _list = execute_terminology_query(session, stmt) @@ -773,7 +793,9 @@ def select_terminology_by_word(session: SessionDep, word: str, oid: int, datasou for row in t_list: pid = str(row.pid) if row.pid is not None else str(row.id) if _map.get(pid) is None: - _map[pid] = {'words': [], 'description': row.description} + _map[pid] = {'words': [], 'description': ''} + if row.pid is None: + _map[pid]['description'] = row.description _map[pid]['words'].append(row.word) _results: list[dict] = [] diff --git a/backend/common/utils/excel.py b/backend/common/utils/excel.py new file mode 100644 index 00000000..659428ab --- /dev/null +++ b/backend/common/utils/excel.py @@ -0,0 +1,12 @@ +import pandas as pd + +def get_excel_column_count(file_path, sheet_name): + """获取Excel文件的列数""" + df_temp = pd.read_excel( + file_path, + sheet_name=sheet_name, + engine='calamine', + header=0, + nrows=0 + ) + return len(df_temp.columns) \ No newline at end of file diff --git a/backend/common/utils/utils.py b/backend/common/utils/utils.py index ee6b0964..ec8c0d85 100644 --- a/backend/common/utils/utils.py +++ b/backend/common/utils/utils.py @@ -6,6 +6,7 @@ from datetime import datetime, timedelta, timezone from logging.handlers import RotatingFileHandler from pathlib import Path +import re from urllib.parse import urlparse from fastapi import Request @@ -263,6 +264,17 @@ def get_origin_from_referer(request: Request): SQLBotLogUtil.error(f"解析 Referer 出错: {e}") return referer +def origin_match_domain(origin: str, domain: str) -> bool: + if not origin or not domain: + return False + origin_normalized = origin.rstrip('/') + + for d in re.split(r'[,;]', domain): + if d.strip().rstrip('/') == origin_normalized: + return True + + return False + def equals_ignore_case(str1: str, *args: str) -> bool: if str1 is None: diff --git a/backend/locales/en.json b/backend/locales/en.json index f6402bbc..c43d4fb8 100644 --- a/backend/locales/en.json +++ b/backend/locales/en.json @@ -48,6 +48,18 @@ "effective_data_sources": "Effective Data Sources", "all_data_sources": "All Data Sources", "synonyms": "Synonyms", + "term_name_template": "Terminology Name (Required)", + "term_description_template": "Terminology Description (Required)", + "effective_data_sources_template": "Effective Data Sources (Multiple supported, separated by \",\")", + "all_data_sources_template": "All Data Sources (Y: Apply to all data sources, N: Apply to specified data sources)", + "synonyms_template": "Synonyms (Multiple supported, separated by \",\")", + "term_name_template_example_1": "Term1", + "term_description_template_example_1": "Term1 Description", + "effective_data_sources_template_example_1": "Datasource1, Datasource2", + "synonyms_template_example_1": "Synonym1, Synonym2", + "term_name_template_example_2": "Term2", + "term_description_template_example_2": "Term2 Description", + "synonyms_template_example_2": "Synonym3", "word_cannot_be_empty": "Term cannot be empty", "description_cannot_be_empty": "Term description cannot be empty", "datasource_not_found": "Datasource not found" @@ -62,6 +74,13 @@ "sample_sql": "Sample SQL", "effective_data_sources": "Effective Data Sources", "advanced_application": "Advanced Application", + "problem_description_template": "Problem Description (Required)", + "sample_sql_template": "Sample SQL (Required)", + "effective_data_sources_template": "Effective Data Sources", + "advanced_application_template": "Advanced Application", + "problem_description_template_example": "Query all IDs in the TEST table", + "effective_data_sources_template_example": "Effective Datasource 1", + "advanced_application_template_example": "Effective Advanced Application Name", "error_info": "Error Information", "question_cannot_be_empty": "Question cannot be empty", "description_cannot_be_empty": "Sample SQL cannot be empty", @@ -75,6 +94,15 @@ "prompt_word_content": "Prompt word content", "effective_data_sources": "Effective Data Sources", "all_data_sources": "All Data Sources", + "prompt_word_name_template": "Prompt Name (Required)", + "prompt_word_content_template": "Prompt Content (Required)", + "effective_data_sources_template": "Effective Data Sources (Multiple supported, separated by \",\")", + "all_data_sources_template": "All Data Sources (Y: Apply to all data sources, N: Apply to specified data sources)", + "prompt_word_name_template_example1": "Prompt1", + "prompt_word_content_template_example1": "Describe your prompt in detail", + "effective_data_sources_template_example1": "Datasource1, Datasource2", + "prompt_word_name_template_example2": "Prompt2", + "prompt_word_content_template_example2": "Describe your prompt in detail", "name_cannot_be_empty": "Name cannot be empty", "prompt_cannot_be_empty": "Prompt content cannot be empty", "type_cannot_be_empty": "Type cannot be empty", @@ -83,5 +111,8 @@ }, "i18n_excel_export": { "data_is_empty": "Form data is empty, unable to export data" + }, + "i18n_excel_import": { + "col_num_not_match": "Number of columns in Excel does not match" } } \ No newline at end of file diff --git a/backend/locales/ko-KR.json b/backend/locales/ko-KR.json index 3c7621c8..87d883f9 100644 --- a/backend/locales/ko-KR.json +++ b/backend/locales/ko-KR.json @@ -48,6 +48,18 @@ "effective_data_sources": "유효 데이터 소스", "all_data_sources": "모든 데이터 소스", "synonyms": "동의어", + "term_name_template": "용어 이름 (필수)", + "term_description_template": "용어 설명 (필수)", + "effective_data_sources_template": "유효 데이터 소스 (여러 개 지원, \",\"로 구분)", + "all_data_sources_template": "모든 데이터 소스 (Y: 모든 데이터 소스에 적용, N: 지정된 데이터 소스에 적용)", + "synonyms_template": "동의어 (여러 개 지원, \",\"로 구분)", + "term_name_template_example_1": "용어1", + "term_description_template_example_1": "용어1 설명", + "effective_data_sources_template_example_1": "데이터소스1, 데이터소스2", + "synonyms_template_example_1": "동의어1, 동의어2", + "term_name_template_example_2": "용어2", + "term_description_template_example_2": "용어2 설명", + "synonyms_template_example_2": "동의어3", "word_cannot_be_empty": "용어는 비울 수 없습니다", "description_cannot_be_empty": "용어 설명은 비울 수 없습니다", "datasource_not_found": "데이터 소스를 찾을 수 없음" @@ -62,6 +74,13 @@ "sample_sql": "예시 SQL", "effective_data_sources": "유효 데이터 소스", "advanced_application": "고급 애플리케이션", + "problem_description_template": "문제 설명 (필수)", + "sample_sql_template": "예시 SQL (필수)", + "effective_data_sources_template": "유효 데이터 소스", + "advanced_application_template": "고급 애플리케이션", + "problem_description_template_example": "TEST 테이블 내 모든 ID 조회", + "effective_data_sources_template_example": "유효 데이터소스1", + "advanced_application_template_example": "유효 고급 애플리케이션 이름", "error_info": "오류 정보", "question_cannot_be_empty": "질문은 비울 수 없습니다", "description_cannot_be_empty": "예시 SQL은 비울 수 없습니다", @@ -75,6 +94,15 @@ "prompt_word_content": "프롬프트 내용", "effective_data_sources": "유효 데이터 소스", "all_data_sources": "모든 데이터 소스", + "prompt_word_name_template": "프롬프트 이름 (필수)", + "prompt_word_content_template": "프롬프트 내용 (필수)", + "effective_data_sources_template": "유효 데이터 소스 (여러 개 지원, \",\"로 구분)", + "all_data_sources_template": "모든 데이터 소스 (Y: 모든 데이터 소스에 적용, N: 지정된 데이터 소스에 적용)", + "prompt_word_name_template_example1": "프롬프트1", + "prompt_word_content_template_example1": "프롬프트를 상세히 설명해 주세요", + "effective_data_sources_template_example1": "데이터소스1, 데이터소스2", + "prompt_word_name_template_example2": "프롬프트2", + "prompt_word_content_template_example2": "프롬프트를 상세히 설명해 주세요", "name_cannot_be_empty": "이름은 비울 수 없습니다", "prompt_cannot_be_empty": "프롬프트 내용은 비울 수 없습니다", "type_cannot_be_empty": "유형은 비울 수 없습니다", @@ -83,5 +111,8 @@ }, "i18n_excel_export": { "data_is_empty": "폼 데이터가 비어 있어 데이터를 내보낼 수 없습니다" + }, + "i18n_excel_import": { + "col_num_not_match": "Excel 열 개수가 일치하지 않습니다" } } \ No newline at end of file diff --git a/backend/locales/zh-CN.json b/backend/locales/zh-CN.json index de454ba8..3393f402 100644 --- a/backend/locales/zh-CN.json +++ b/backend/locales/zh-CN.json @@ -48,6 +48,18 @@ "effective_data_sources": "生效数据源", "all_data_sources": "所有数据源", "synonyms": "同义词", + "term_name_template": "术语名称(必填)", + "term_description_template": "术语描述(必填)", + "effective_data_sources_template": "生效数据源(支持多个,用\",\"分割)", + "all_data_sources_template": "所有数据源(Y:应用到全部数据源,N:应用到指定数据源)", + "synonyms_template": "同义词(支持多个,用\",\"分割)", + "term_name_template_example_1": "术语1", + "term_description_template_example_1": "术语1描述", + "effective_data_sources_template_example_1": "生效数据源1, 生效数据源2", + "synonyms_template_example_1": "同义词1, 同义词2", + "term_name_template_example_2": "术语2", + "term_description_template_example_2": "术语2描述", + "synonyms_template_example_2": "同义词3", "word_cannot_be_empty": "术语不能为空", "description_cannot_be_empty": "术语描述不能为空", "datasource_not_found": "找不到数据源" @@ -62,6 +74,13 @@ "sample_sql": "示例 SQL", "effective_data_sources": "生效数据源", "advanced_application": "高级应用", + "problem_description_template": "问题描述(必填)", + "sample_sql_template": "示例 SQL(必填)", + "effective_data_sources_template": "生效数据源", + "advanced_application_template": "高级应用", + "problem_description_template_example": "查询TEST表内所有ID", + "effective_data_sources_template_example": "生效数据源1", + "advanced_application_template_example": "生效高级应用名称", "error_info": "错误信息", "question_cannot_be_empty": "问题不能为空", "description_cannot_be_empty": "示例 SQL 不能为空", @@ -75,6 +94,15 @@ "prompt_word_content": "提示词内容", "effective_data_sources": "生效数据源", "all_data_sources": "所有数据源", + "prompt_word_name_template": "提示词名称(必填)", + "prompt_word_content_template": "提示词内容(必填)", + "effective_data_sources_template": "生效数据源(支持多个,用\",\"分割)", + "all_data_sources_template": "所有数据源(Y:应用到全部数据源,N:应用到指定数据源)", + "prompt_word_name_template_example1": "提示词1", + "prompt_word_content_template_example1": "详细描述你的提示词", + "effective_data_sources_template_example1": "生效数据源1, 生效数据源2", + "prompt_word_name_template_example2": "提示词2", + "prompt_word_content_template_example2": "详细描述你的提示词", "name_cannot_be_empty": "名称不能为空", "prompt_cannot_be_empty": "提示词内容不能为空", "type_cannot_be_empty": "类型不能为空", @@ -83,5 +111,8 @@ }, "i18n_excel_export": { "data_is_empty": "表单数据为空,无法导出数据" + }, + "i18n_excel_import": { + "col_num_not_match": "EXCEL列数量不匹配" } } \ No newline at end of file diff --git a/backend/pyproject.toml b/backend/pyproject.toml index cc242872..ca220509 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ "pyyaml (>=6.0.2,<7.0.0)", "fastapi-mcp (>=0.3.4,<0.4.0)", "tabulate>=0.9.0", - "sqlbot-xpack>=0.0.3.47,<1.0.0", + "sqlbot-xpack>=0.0.3.53,<1.0.0", "fastapi-cache2>=0.2.2", "sqlparse>=0.5.3", "redis>=6.2.0", @@ -51,6 +51,7 @@ dependencies = [ "dmpython==2.5.22; platform_system != 'Darwin'", "redshift-connector>=2.1.8", "elasticsearch[requests] (>=7.10,<8.0)", + "ldap3>=2.9.1", ] [project.optional-dependencies] diff --git a/backend/templates/sql_examples/Oracle.yaml b/backend/templates/sql_examples/Oracle.yaml index b7293465..ba3015b2 100644 --- a/backend/templates/sql_examples/Oracle.yaml +++ b/backend/templates/sql_examples/Oracle.yaml @@ -1,4 +1,18 @@ template: + process_check: | + + 1. 分析用户问题,确定查询需求 + 2. 根据表结构生成基础SQL + 3. 强制检查:SQL是否包含GROUP BY/聚合函数? + 4. 如果是GROUP BY查询:必须使用外层查询结构包裹 + 5. 强制检查:应用数据量限制规则 + 6. 应用其他规则(引号、别名等) + 7. 最终验证:GROUP BY查询的ROWNUM位置是否正确? + 8. 强制检查:检查语法是否正确? + 9. 确定图表类型 + 10. 返回JSON结果 + + quot_rule: | 必须对数据库名、表名、字段名、别名外层加双引号(")。 @@ -10,42 +24,61 @@ template: limit_rule: | - - 当需要限制行数时: - 1. 12c以下版本必须使用ROWNUM语法 - 2. 12c+版本推荐使用FETCH FIRST语法 - - 版本适配: - - Oracle 12c以下:必须使用 WHERE ROWNUM <= N - - Oracle 12c+:推荐使用 FETCH FIRST N ROWS ONLY - - - 重要:ROWNUM必须放在正确的位置,避免语法错误 - 1. 单层查询:ROWNUM直接跟在WHERE子句后 - - 2. 多层查询:ROWNUM只能放在最外层 - - 3. 禁止的错误写法: - - SELECT ... FROM table - WHERE conditions - GROUP BY ... - ORDER BY ... - WHERE ROWNUM <= N -- 错误:不能有多个WHERE - - 4. 正确顺序:WHERE → GROUP BY → HAVING → ORDER BY → ROWNUM - 5. 括号位置:从内层SELECT开始到内层结束都要括起来 - - SELECT ... FROM ( - -- 内层完整查询(包含自己的SELECT、FROM、WHERE、GROUP BY、ORDER BY) - SELECT columns FROM table WHERE conditions GROUP BY ... ORDER BY ... - ) alias WHERE ROWNUM <= N - - + + Oracle版本语法适配 + + 如果db-engine版本号小于12 + 必须使用ROWNUM语法 + 如果db-engine版本号大于等于12 + 推荐使用FETCH FIRST语法 + + + + Oracle数据库 FETCH FIRST 语法规范 + 若使用 FETCH FIRST 语法,则必须遵循该规范 + + + + + + Oracle数据库ROWNUM语法规范 + 若使用ROWNUM语法,则必须遵循该规范 + + + 简单查询 + + + + 语法禁区 + + - 禁止多个WHERE子句 + - 禁止ROWNUM在GROUP BY内层(影响分组结果) + - 禁止括号不完整 + + + + + + GROUP BY查询的ROWNUM强制规范(必须严格遵守) + 所有包含GROUP BY或聚合函数的查询必须使用外层查询结构 + ROWNUM必须放在最外层查询的WHERE子句中 + + + 如果SQL包含GROUP BY、COUNT、SUM等聚合函数 + 必须使用:SELECT ... FROM (内层完整查询) WHERE ROWNUM <= N + 否则(简单查询) + 可以使用:SELECT ... FROM table WHERE conditions AND ROWNUM <= N + + + + -- 错误:ROWNUM在内层影响分组结果 + SELECT ... GROUP BY ... WHERE ROWNUM <= N + + + + -- 正确:ROWNUM在外层 + SELECT ... FROM (SELECT ... GROUP BY ...) WHERE ROWNUM <= N + other_rule: | @@ -73,7 +106,7 @@ template: SELECT "订单ID", "金额" FROM "TEST"."ORDERS" "t1" WHERE ROWNUM <= 100 -- 错误:缺少英文别名 SELECT COUNT("订单ID") FROM "TEST"."ORDERS" "t1" -- 错误:函数未加别名 - + SELECT "t1"."订单ID" AS "order_id", "t1"."金额" AS "amount", @@ -90,7 +123,7 @@ template: SELECT DATE, status FROM PUBLIC.USERS -- 错误:未处理关键字和引号 SELECT "DATE", ROUND(active_ratio) FROM "PUBLIC"."USERS" -- 错误:百分比格式错误 - + SELECT "u"."DATE" AS "create_date", TO_CHAR("u"."active_ratio" * 100, '990.99') || '%' AS "active_percent" @@ -98,6 +131,14 @@ template: WHERE "u"."status" = 1 AND ROWNUM <= 1000 + + SELECT + "u"."DATE" AS "create_date", + TO_CHAR("u"."active_ratio" * 100, '990.99') || '%' AS "active_percent" + FROM "PUBLIC"."USERS" "u" + WHERE "u"."status" = 1 + FETCH FIRST 1000 ROWS ONLY + @@ -108,9 +149,9 @@ template: count(*) AS "user_count" FROM "PUBLIC"."USERS" "u" WHERE "u"."status" = 1 - AND ROWNUM <= 100 + AND ROWNUM <= 100 -- 严重错误:影响分组结果! GROUP BY "u"."DEPARTMENT" - ORDER BY "department_name" -- 错误:ROWNUM 应当写在最外层,这样会导致查询结果条数比实际数据的数量少 + ORDER BY "department_name" SELECT "department_name", "user_count" FROM @@ -123,7 +164,7 @@ template: ORDER BY "department_name" WHERE ROWNUM <= 100 -- 错误:语法错误,同级内只能有一个WHERE - + SELECT "department_name", "user_count" FROM ( SELECT "u"."DEPARTMENT" AS "department_name", @@ -133,12 +174,22 @@ template: GROUP BY "u"."DEPARTMENT" ORDER BY "department_name" ) - WHERE ROWNUM <= 100 -- 外层限制(确保最终结果可控) + WHERE ROWNUM <= 100 -- 正确,在外层限制数量(确保最终结果可控) + + + SELECT + "u"."DEPARTMENT" AS "department_name", + count(*) AS "user_count" + FROM "PUBLIC"."USERS" "u" + WHERE "u"."status" = 1 + GROUP BY "u"."DEPARTMENT" + ORDER BY "department_name" + FETCH FIRST 100 ROWS ONLY - example_engine: Oracle 19c + example_engine: Oracle 11g example_answer_1: | {"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\"","tables":["sample_country_gdp"],"chart-type":"line"} example_answer_1_with_limit: | diff --git a/backend/templates/template.yaml b/backend/templates/template.yaml index 7e15db29..8f51253b 100644 --- a/backend/templates/template.yaml +++ b/backend/templates/template.yaml @@ -6,19 +6,43 @@ template: {data_training} sql: + process_check: | + + 1. 分析用户问题,确定查询需求 + 2. 根据表结构生成基础SQL + 3. 强制检查:应用数据量限制规则 + 4. 应用其他规则(引号、别名等) + 5. 强制检查:检查语法是否正确? + 6. 确定图表类型 + 7. 确定对话标题 + 8. 返回JSON结果 + query_limit: | - - 1. 必须遵守:所有生成的SQL必须包含数据量限制 - 2. 默认限制:1000条(除非用户明确指定其他数量) + + 数据量限制策略(必须严格遵守 - 零容忍) + + 所有生成的SQL必须包含数据量限制,这是强制要求 + 默认限制:1000条(除非用户明确指定其他数量) + 忘记添加数据量限制是不可接受的错误 + + + + 如果生成的SQL没有数据量限制,必须重新生成 + 在最终返回前必须验证限制是否存在 + no_query_limit: | - - 如果没有指定数据条数的限制,则查询的SQL默认返回全部数据 + + 数据量限制策略(必须严格遵守) + + 默认不限制数据量,返回全部数据(除非用户明确指定其他数量) + 不要臆测场景可能需要的数据量限制,以用户明确指定的数量为准 + system: | 你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。 - 你当前的任务是根据给定的表结构和用户问题生成SQL语句、可能适合展示的图表类型以及该SQL中所用到的表名。 + 你当前的任务是根据给定的表结构和用户问题生成SQL语句、对话标题、可能适合展示的图表类型以及该SQL中所用到的表名。 我们会在块内提供给你信息,帮助你生成SQL: 内有等信息; 其中,:提供数据库引擎及版本信息; @@ -27,9 +51,10 @@ template: :提供一组SQL示例,你可以参考这些示例来生成你的回答,其中内是提问,内是对于该提问的解释或者对应应该回答的SQL示例。 若有块,它会提供一组,可能会是额外添加的背景信息,或者是额外的生成SQL的要求,请结合额外信息或要求后生成你的回答。 用户的提问在内,内则会提供上次执行你提供的SQL时会出现的错误信息,内的会告诉你用户当前提问的时间 + 你必须遵守内规定的生成SQL规则 + 你必须遵守内规定的检查步骤生成你的回答 - 你必须遵守以下规则: 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 @@ -48,7 +73,7 @@ template: 请使用JSON格式返回你的回答: - 若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}} + 若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table","brief":"如何需要生成对话标题,在这里填写你生成的对话标题,否则不需要这个字段"}} 若不能生成,则返回格式如:{{"success":false,"message":"说明无法生成SQL的原因"}} @@ -88,8 +113,13 @@ template: 我们目前的情况适用于单指标、多分类的场景(展示table除外) + + 是否生成对话标题在内,如果为True需要生成,否则不需要生成,生成的对话标题要求在20字以内 + + {process_check} + {basic_sql_examples} @@ -225,6 +255,9 @@ template: {question} + + {change_title} + chart: system: | @@ -369,6 +402,8 @@ template: ### 以往提问: {old_questions} + + /no_think analysis: system: | diff --git a/frontend/public/assistant.js b/frontend/public/assistant.js index 270c28b7..a986eaa9 100644 --- a/frontend/public/assistant.js +++ b/frontend/public/assistant.js @@ -539,12 +539,15 @@ return } if (event.data?.busi == 'ready' && event.data?.ready) { - const certificate = parsrCertificate(data) params = { - busi: 'certificate', - certificate, eventName, messageId: id, + hostOrigin: window.location.origin, + } + if (data.type === 1) { + const certificate = parsrCertificate(data) + params['busi'] = 'certificate' + params['certificate'] = certificate } const contentWindow = iframe.contentWindow contentWindow.postMessage(params, url) @@ -596,10 +599,7 @@ tempData['userFlag'] = userFlag tempData['history'] = history initsqlbot_assistant(tempData) - if (data.type == 1) { - registerMessageEvent(id, tempData) - // postMessage the certificate to iframe - } + registerMessageEvent(id, tempData) }) .catch((e) => { showMsg('嵌入失败', e.message) diff --git a/frontend/src/api/chat.ts b/frontend/src/api/chat.ts index affb883a..d840b31e 100644 --- a/frontend/src/api/chat.ts +++ b/frontend/src/api/chat.ts @@ -331,6 +331,9 @@ export const chatApi = { recommendQuestions: (record_id: number | undefined, controller?: AbortController) => { return request.fetchStream(`/chat/recommend_questions/${record_id}`, {}, controller) }, + recentQuestions: (datasource_id?: number): Promise => { + return request.get(`/chat/recent_questions/${datasource_id}`) + }, checkLLMModel: () => request.get('/system/aimodel/default', { requestOptions: { silent: true } }), export2Excel: (record_id: number | undefined) => request.get(`/chat/record/${record_id}/excel/export`, { diff --git a/frontend/src/api/professional.ts b/frontend/src/api/professional.ts index 64408d3a..451ff885 100644 --- a/frontend/src/api/professional.ts +++ b/frontend/src/api/professional.ts @@ -2,10 +2,7 @@ import { request } from '@/utils/request' export const professionalApi = { getList: (pageNum: any, pageSize: any, params: any) => - request.get(`/system/terminology/page/${pageNum}/${pageSize}`, { - params, - }), - + request.get(`/system/terminology/page/${pageNum}/${pageSize}${params}`), updateEmbedded: (data: any) => request.put('/system/terminology', data), deleteEmbedded: (params: any) => request.delete('/system/terminology', { data: params }), getOne: (id: any) => request.get(`/system/terminology/${id}`), diff --git a/frontend/src/api/prompt.ts b/frontend/src/api/prompt.ts index 316b64e3..1d98b2bf 100644 --- a/frontend/src/api/prompt.ts +++ b/frontend/src/api/prompt.ts @@ -2,9 +2,7 @@ import { request } from '@/utils/request' export const promptApi = { getList: (pageNum: any, pageSize: any, type: any, params: any) => - request.get(`/system/custom_prompt/${type}/page/${pageNum}/${pageSize}`, { - params, - }), + request.get(`/system/custom_prompt/${type}/page/${pageNum}/${pageSize}${params}`), updateEmbedded: (data: any) => request.put(`/system/custom_prompt`, data), deleteEmbedded: (params: any) => request.delete('/system/custom_prompt', { data: params }), getOne: (id: any) => request.get(`/system/custom_prompt/${id}`), diff --git a/frontend/src/api/recommendedApi.ts b/frontend/src/api/recommendedApi.ts index 64f83538..7c73b62e 100644 --- a/frontend/src/api/recommendedApi.ts +++ b/frontend/src/api/recommendedApi.ts @@ -3,5 +3,6 @@ import { request } from '@/utils/request' export const recommendedApi = { get_recommended_problem: (dsId: any) => request.get(`/recommended_problem/get_datasource_recommended/${dsId}`), - save_recommended_problem: (data: any) => request.post(`/recommended_problem/save`, data), + save_recommended_problem: (data: any) => + request.post(`/recommended_problem/save_recommended_problem`, data), } diff --git a/frontend/src/api/setting.ts b/frontend/src/api/setting.ts index 456e5b6d..354aac58 100644 --- a/frontend/src/api/setting.ts +++ b/frontend/src/api/setting.ts @@ -17,4 +17,10 @@ export const settingsApi = { requestOptions: { customError: true }, } ), + + downloadTemplate: (url: any) => + request.get(url, { + responseType: 'blob', + requestOptions: { customError: true }, + }), } diff --git a/frontend/src/assets/svg/icon_qr_outlined.svg b/frontend/src/assets/svg/icon_qr_outlined.svg new file mode 100644 index 00000000..8fb87a4f --- /dev/null +++ b/frontend/src/assets/svg/icon_qr_outlined.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/frontend/src/assets/svg/icon_quick_question.svg b/frontend/src/assets/svg/icon_quick_question.svg new file mode 100644 index 00000000..dbfdebce --- /dev/null +++ b/frontend/src/assets/svg/icon_quick_question.svg @@ -0,0 +1 @@ + diff --git a/frontend/src/assets/svg/logo_cas.svg b/frontend/src/assets/svg/logo_cas.svg index 72242460..ad09ca8b 100644 --- a/frontend/src/assets/svg/logo_cas.svg +++ b/frontend/src/assets/svg/logo_cas.svg @@ -1,14 +1,14 @@ - + - + - - - + + + diff --git a/frontend/src/assets/svg/logo_ldap.svg b/frontend/src/assets/svg/logo_ldap.svg index 57deedf2..bd13cd4a 100644 --- a/frontend/src/assets/svg/logo_ldap.svg +++ b/frontend/src/assets/svg/logo_ldap.svg @@ -1,4 +1,4 @@ - + @@ -6,7 +6,7 @@ - + diff --git a/frontend/src/assets/svg/logo_oauth.svg b/frontend/src/assets/svg/logo_oauth.svg index 89ef68c1..aaca69e3 100644 --- a/frontend/src/assets/svg/logo_oauth.svg +++ b/frontend/src/assets/svg/logo_oauth.svg @@ -1,3 +1,3 @@ - - + + diff --git a/frontend/src/assets/svg/logo_oidc.svg b/frontend/src/assets/svg/logo_oidc.svg new file mode 100644 index 00000000..99569140 --- /dev/null +++ b/frontend/src/assets/svg/logo_oidc.svg @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + diff --git a/frontend/src/components/layout/Person.vue b/frontend/src/components/layout/Person.vue index d7972033..9080c127 100644 --- a/frontend/src/components/layout/Person.vue +++ b/frontend/src/components/layout/Person.vue @@ -31,6 +31,7 @@ const name = computed(() => userStore.getName) const account = computed(() => userStore.getAccount) const currentLanguage = computed(() => userStore.getLanguage) const isAdmin = computed(() => userStore.isAdmin) +const isLocalUser = computed(() => !userStore.getOrigin) const dialogVisible = ref(false) const aboutRef = ref() const languageList = computed(() => [ @@ -83,8 +84,9 @@ const savePwdHandler = () => { pwdFormRef.value?.submit() } const logout = async () => { - await userStore.logout() - router.push('/login') + if (!(await userStore.logout())) { + router.push('/login') + } } @@ -118,7 +120,7 @@ const logout = async () => {
{{ $t('common.system_manage') }}
-
+
diff --git a/frontend/src/components/layout/index.vue b/frontend/src/components/layout/index.vue index de8c2ee7..f83ef983 100644 --- a/frontend/src/components/layout/index.vue +++ b/frontend/src/components/layout/index.vue @@ -236,8 +236,9 @@ const menuSelect = (e: any) => { router.push(e.index) } const logout = async () => { - await userStore.logout() - router.push('/login') + if (!(await userStore.logout())) { + router.push('/login') + } } const toSystem = () => { router.push('/system') diff --git a/frontend/src/i18n/en.json b/frontend/src/i18n/en.json index ca3d1b7a..259afd4b 100644 --- a/frontend/src/i18n/en.json +++ b/frontend/src/i18n/en.json @@ -117,6 +117,12 @@ "english": "English", "re_upload": "Re-upload", "not_exceed_50mb": "Supports XLS, XLSX, CSV formats, file size does not exceed 50MB", + "excel_file_type_limit": "Only XLS and XLSX formats are supported", + "click_to_select_file": "Click to select file", + "upload_hint_first": "Please ", + "upload_hint_download_template": "download the template", + "upload_hint_end": " first, then upload after filling it out as required", + "continue_to_upload": "Continue to import", "reset_password": "Reset password", "password_reset_successful": "Password reset successful", "or": "Or", @@ -184,6 +190,12 @@ "chart_selected": "Selected {0}" }, "qa": { + "recommended_repetitive_tips": "Duplicate questions exist", + "retrieve_error": "Model recommendation failed...", + "retrieve_again": "Retrieve Again", + "recently": "recently", + "recommend": "recommend", + "quick_question": "quick question", "new_chat": "New Chat", "start_sqlbot": "New Chat", "title": "Data Q&A", @@ -309,6 +321,7 @@ } }, "datasource": { + "recommended_problem_tips": "Custom configuration requires at least one problem, each problem should be 2-200 characters", "recommended_problem_configuration": "Recommended Problem Configuration", "problem_generation_method": "Problem Generation Method", "ai_automatic_generation": "AI Automatic Generation", @@ -564,7 +577,7 @@ "application_name": "Application name", "application_description": "Application description", "cross_domain_settings": "Cross-domain settings", - "third_party_address": "Please enter the embedded third party address", + "third_party_address": "Please enter the embedded third party address,multiple items separated by semicolons", "set_to_private": "Set as private", "set_to_public": "Set as public", "public": "Public", @@ -572,7 +585,9 @@ "creating_advanced_applications": "Creating Advanced Applications", "configure_interface": "Configure interface", "interface_url": "Interface URL", - "format_is_incorrect": "format is incorrect", + "format_is_incorrect": "format is incorrect{msg}", + "domain_format_incorrect": ", start with http:// or https://, no trailing slash (/), multiple domains separated by semicolons", + "interface_url_incorrect": ",enter a relative path starting with /", "aes_enable": "Enable AES encryption", "aes_enable_tips": "The fields (host, user, password, dataBase, schema) are all encrypted using the AES-CBC-PKCS5Padding encryption method", "bit": "bit", @@ -734,7 +749,22 @@ "revoke_url": "Revocation URL", "oauth2_field_mapping_placeholder": "Example: {'{'}\"account\": \"OAuth2Account\", \"name\": \"OAuth2Name\", \"email\": \"email\"{'}'}", "token_auth_method": "Token auth method", - "userinfo_auth_method": "Userinfo auth method" + "userinfo_auth_method": "Userinfo auth method", + "oidc_settings": "OIDC Settings", + "metadata_url": "Metadata URL", + "realm": "Realm", + "oidc_field_mapping_placeholder": "e.g., {\"account\": \"oidcAccount\", \"name\": \"oidcName\", \"email\": \"email\"}", + "ldap_settings": "LDAP Settings", + "server_address": "Server Address", + "server_address_placeholder": "Example: ldap://ldap.example.com:389", + "bind_dn": "Bind DN", + "bind_dn_placeholder": "Example: cn=admin,dc=example,dc=com", + "bind_pwd": "Bind Password", + "ou": "User OU", + "ou_placeholder": "Example: ou=users,dc=example,dc=com", + "user_filter": "User Filter", + "user_filter_placeholder": "Example: uid", + "ldap_field_mapping_placeholder": "Example: {\"account\": \"ldapAccount\", \"name\": \"ldapName\", \"email\": \"mail\"}" }, "login": { "default_login": "Default", @@ -746,7 +776,8 @@ "qr_code": "QR Code", "platform_disable": "{0} settings are not enabled!", "input_account": "Please enter account", - "redirect_2_auth": "Redirecting to {0} authentication, {1} seconds..." + "redirect_2_auth": "Redirecting to {0} authentication, {1} seconds...", + "redirect_immediately": "Redirecting immediately" }, "supplier": { "alibaba_cloud_bailian": "Alibaba Cloud Bailian", diff --git a/frontend/src/i18n/ko-KR.json b/frontend/src/i18n/ko-KR.json index a9bf5820..a79aa92e 100644 --- a/frontend/src/i18n/ko-KR.json +++ b/frontend/src/i18n/ko-KR.json @@ -117,6 +117,12 @@ "english": "English", "re_upload": "다시 업로드", "not_exceed_50mb": "XLS, XLSX, CSV 형식을 지원하며, 파일 크기는 50MB를 초과할 수 없습니다", + "excel_file_type_limit": "XLS 및 XLSX 형식만 지원됩니다", + "click_to_select_file": "파일 선택을 클릭하세요", + "upload_hint_first": "먼저 ", + "upload_hint_download_template": "템플릿을 다운로드", + "upload_hint_end": ", 한 후 요구사항에 따라 작성하여 업로드하세요", + "continue_to_upload": "계속 가져오기", "reset_password": "비밀번호 재설정", "password_reset_successful": "비밀번호 재설정 성공", "or": "또는", @@ -184,6 +190,12 @@ "chart_selected": "{0}개 선택됨" }, "qa": { + "recommended_repetitive_tips": "중복된 문제가 존재합니다", + "retrieve_error": "모델 추천 문제 실패...", + "retrieve_again": "다시 가져오기", + "recently": "최근", + "recommend": "추천", + "quick_question": "빠른 질문", "new_chat": "새 대화 생성", "start_sqlbot": "데이터 조회 시작", "title": "스마트 데이터 조회", @@ -309,6 +321,7 @@ } }, "datasource": { + "recommended_problem_tips": "사용자 정의 구성으로 최소 한 개의 문제를 생성하세요, 각 문제는 2~200자로 작성", "recommended_problem_configuration": "추천 문제 구성", "problem_generation_method": "문제 생성 방식", "ai_automatic_generation": "AI 자동 생성", @@ -564,14 +577,16 @@ "application_name": "애플리케이션 이름", "application_description": "애플리케이션 설명", "cross_domain_settings": "교차 도메인 설정", - "third_party_address": "임베디드할 제3자 주소를 입력하십시오", + "third_party_address": "임베디드할 제3자 주소를 입력하십시오, 여러 항목을 세미콜론으로 구분", "set_to_private": "비공개로 설정", "set_to_public": "공개로 설정", "public": "공개", "private": "비공개", "configure_interface": "인터페이스 설정", "interface_url": "인터페이스 URL", - "format_is_incorrect": "형식이 올바르지 않습니다", + "format_is_incorrect": "형식이 올바르지 않습니다{msg}", + "domain_format_incorrect": ", http:// 또는 https://로 시작해야 하며, 슬래시(/)로 끝날 수 없습니다. 여러 도메인은 세미콜론으로 구분합니다", + "interface_url_incorrect": ", 상대 경로를 입력해주세요. /로 시작합니다", "aes_enable": "AES 암호화 활성화", "aes_enable_tips": "암호화 필드 (host, user, password, dataBase, schema)는 모두 AES-CBC-PKCS5Padding 암호화 방식을 사용합니다", "bit": "비트", @@ -734,7 +749,22 @@ "revoke_url": "취소 URL", "oauth2_field_mapping_placeholder": "예: {'{'}\"account\": \"OAuth2Account\", \"name\": \"OAuth2Name\", \"email\": \"email\"{'}'}", "token_auth_method": "토큰 인증 방식", - "userinfo_auth_method": "사용자 정보 인증 방식" + "userinfo_auth_method": "사용자 정보 인증 방식", + "oidc_settings": "OIDC 설정", + "metadata_url": "메타데이터 URL", + "realm": "영역", + "oidc_field_mapping_placeholder": "예: {\"account\": \"oidcAccount\", \"name\": \"oidcName\", \"email\": \"email\"}", + "ldap_settings": "LDAP 설정", + "server_address": "서버 주소", + "server_address_placeholder": "예: ldap://ldap.example.com:389", + "bind_dn": "바인드 DN", + "bind_dn_placeholder": "예: cn=admin,dc=example,dc=com", + "bind_pwd": "바인드 비밀번호", + "ou": "사용자 OU", + "ou_placeholder": "예: ou=users,dc=example,dc=com", + "user_filter": "사용자 필터", + "user_filter_placeholder": "예: uid", + "ldap_field_mapping_placeholder": "예: {\"account\": \"ldapAccount\", \"name\": \"ldapName\", \"email\": \"mail\"}" }, "login": { "default_login": "기본값", @@ -746,7 +776,8 @@ "qr_code": "QR 코드", "platform_disable": "{0} 설정이 활성화되지 않았습니다!", "input_account": "계정을 입력해 주세요", - "redirect_2_auth": "{0} 인증으로 리디렉션 중입니다, {1}초..." + "redirect_2_auth": "{0} 인증으로 리디렉션 중입니다, {1}초...", + "redirect_immediately": "지금 이동" }, "supplier": { "alibaba_cloud_bailian": "알리바바 클라우드 바이리엔", diff --git a/frontend/src/i18n/zh-CN.json b/frontend/src/i18n/zh-CN.json index 0f9d6e2b..65917ec8 100644 --- a/frontend/src/i18n/zh-CN.json +++ b/frontend/src/i18n/zh-CN.json @@ -117,6 +117,12 @@ "english": "English", "re_upload": "重新上传", "not_exceed_50mb": "支持 XLS、XLSX、CSV 格式,文件大小不超过 50MB", + "excel_file_type_limit": "仅支持 XLS、XLSX 格式", + "click_to_select_file": "点击选择文件", + "upload_hint_first": "先", + "upload_hint_download_template": "下载模板", + "upload_hint_end": ",按要求填写后上传", + "continue_to_upload": "继续导入", "reset_password": "重置密码", "password_reset_successful": "重置密码成功", "or": "或者", @@ -184,6 +190,11 @@ "chart_selected": "已选{0}" }, "qa": { + "retrieve_error": "模型推荐问题失败...", + "retrieve_again": "重新获取", + "recently": "最近", + "recommend": "推荐", + "quick_question": "快捷提问", "new_chat": "新建对话", "start_sqlbot": "开启问数", "title": "智能问数", @@ -309,6 +320,8 @@ } }, "datasource": { + "recommended_repetitive_tips": "存在重复问题", + "recommended_problem_tips": "自定义配置至少一个问题,每个问题2-200个字符", "recommended_problem_configuration": "推荐问题配置", "problem_generation_method": "问题生成方式", "ai_automatic_generation": "AI 自动生成", @@ -564,14 +577,16 @@ "application_name": "应用名称", "application_description": "应用描述", "cross_domain_settings": "跨域设置", - "third_party_address": "请输入嵌入的第三方地址", + "third_party_address": "请输入嵌入的第三方地址,多个以分号分割", "set_to_private": "设为私有", "set_to_public": "设为公共", "public": "公共", "private": "私有", "configure_interface": "配置接口", "interface_url": "接口 URL", - "format_is_incorrect": "格式不对", + "format_is_incorrect": "格式不对{msg}", + "domain_format_incorrect": ",http或https开头,不能以 / 结尾,多个域名以分号(半角)分隔", + "interface_url_incorrect": ",请填写相对路径,以/开头", "aes_enable": "开启 AES 加密", "aes_enable_tips": "加密字段 (host, user, password, dataBase, schema) 均采用 AES-CBC-PKCS5Padding 加密方式", "bit": "位", @@ -734,7 +749,22 @@ "revoke_url": "撤销地址", "oauth2_field_mapping_placeholder": "例如:{'{'}\"account\": \"oauth2Account\", \"name\": \"oauth2Name\", \"email\": \"email\"{'}'}", "token_auth_method": "Token 认证方式", - "userinfo_auth_method": "用户信息认证方式" + "userinfo_auth_method": "用户信息认证方式", + "oidc_settings": "OIDC 设置", + "metadata_url": "元数据地址", + "realm": "领域", + "oidc_field_mapping_placeholder": "例如:{'{'}\"account\": \"oidcAccount\", \"name\": \"oidcName\", \"email\": \"email\"{'}'}", + "ldap_settings": "LDAP 设置", + "server_address": "服务器地址", + "server_address_placeholder": "例如:ldap://ldap.example.com:389", + "bind_dn": "绑定 DN", + "bind_dn_placeholder": "例如:cn=admin,dc=example,dc=com", + "bind_pwd": "绑定密码", + "ou": "用户 OU", + "ou_placeholder": "例如:ou=users,dc=example,dc=com", + "user_filter": "用户过滤器", + "user_filter_placeholder": "例如:uid", + "ldap_field_mapping_placeholder": "例如:{'{'}\"account\": \"ldapAccount\", \"name\": \"ldapName\", \"email\": \"mail\"{'}'}" }, "login": { "default_login": "默认", @@ -746,7 +776,8 @@ "qr_code": "二维码", "platform_disable": "{0}设置未开启!", "input_account": "请输入账号", - "redirect_2_auth": "正在跳转至 {0} 认证,{1} 秒..." + "redirect_2_auth": "正在跳转至 {0} 认证,{1} 秒...", + "redirect_immediately": "立即跳转" }, "supplier": { "alibaba_cloud_bailian": "阿里云百炼", diff --git a/frontend/src/stores/appearance.ts b/frontend/src/stores/appearance.ts index 217f5d35..2093f300 100644 --- a/frontend/src/stores/appearance.ts +++ b/frontend/src/stores/appearance.ts @@ -8,7 +8,6 @@ import { setTitle, setCurrentColor } from '@/utils/utils' const basePath = import.meta.env.VITE_API_BASE_URL const baseUrl = basePath + '/system/appearance/picture/' import { isBtnShow } from '@/utils/utils' -import type { LinkHTMLAttributes } from 'vue' interface AppearanceState { themeColor?: string customColor?: string @@ -68,8 +67,8 @@ export const useAppearanceStore = defineStore('appearanceStore', { showDemoTips: false, demoTipsContent: '', fontList: [], - pc_welcome: '', - pc_welcome_desc: '', + pc_welcome: undefined, + pc_welcome_desc: undefined, } }, getters: { @@ -311,7 +310,7 @@ export const useAppearanceStore = defineStore('appearanceStore', { }) const setLinkIcon = (linkWeb?: string) => { - const link = document.querySelector('link[rel="icon"]') as LinkHTMLAttributes + const link = document.querySelector('link[rel="icon"]') as HTMLLinkElement if (link) { if (linkWeb) { link['href'] = baseUrl + linkWeb diff --git a/frontend/src/stores/assistant.ts b/frontend/src/stores/assistant.ts index 5b27fec4..46b7a789 100644 --- a/frontend/src/stores/assistant.ts +++ b/frontend/src/stores/assistant.ts @@ -21,6 +21,7 @@ interface AssistantState { online: boolean pageEmbedded?: boolean history: boolean + hostOrigin: string requestPromiseMap: Map } @@ -36,6 +37,7 @@ export const AssistantStore = defineStore('assistant', { online: false, pageEmbedded: false, history: true, + hostOrigin: '', requestPromiseMap: new Map(), } }, @@ -70,6 +72,9 @@ export const AssistantStore = defineStore('assistant', { getEmbedded(): boolean { return this.assistant && this.type === 4 }, + getHostOrigin(): string { + return this.hostOrigin + }, }, actions: { refreshCertificate() { @@ -138,6 +143,9 @@ export const AssistantStore = defineStore('assistant', { setHistory(history: boolean) { this.history = history ?? true }, + setHostOrigin(origin: string) { + this.hostOrigin = origin + }, async setChat() { if (!this.assistant) { return null diff --git a/frontend/src/stores/user.ts b/frontend/src/stores/user.ts index d3bc3cd4..330431af 100644 --- a/frontend/src/stores/user.ts +++ b/frontend/src/stores/user.ts @@ -18,6 +18,7 @@ interface UserState { exp: number time: number weight: number + origin: number platformInfo: any | null [key: string]: string | number | any | null } @@ -34,6 +35,7 @@ export const UserStore = defineStore('user', { exp: 0, time: 0, weight: 0, + origin: 0, platformInfo: null, } }, @@ -68,6 +70,9 @@ export const UserStore = defineStore('user', { getWeight(): number { return this.weight }, + getOrigin(): number { + return this.origin + }, isSpaceAdmin(): boolean { return this.uid === '1' || !!this.weight }, @@ -91,24 +96,37 @@ export const UserStore = defineStore('user', { if (res) { window.location.href = res window.open(res, '_self') + return res } if (getQueryString('code') && getQueryString('state')?.includes('oauth2_state')) { const logout_url = location.origin + location.pathname + '#/login' window.location.href = logout_url window.open(res, logout_url) + return logout_url } + return null }, async info() { const res: any = await AuthApi.info() const res_data = res || {} - const keys = ['uid', 'account', 'name', 'oid', 'language', 'exp', 'time', 'weight'] as const + const keys = [ + 'uid', + 'account', + 'name', + 'oid', + 'language', + 'exp', + 'time', + 'weight', + 'origin', + ] as const keys.forEach((key) => { const dkey = key === 'uid' ? 'id' : key const value = res_data[dkey] - if (key === 'exp' || key === 'time' || key === 'weight') { + if (key === 'exp' || key === 'time' || key === 'weight' || key === 'origin') { this[key] = Number(value) } else { this[key] = String(value) @@ -165,6 +183,10 @@ export const UserStore = defineStore('user', { wsCache.set('user.weight', weight) this.weight = weight }, + setOrigin(origin: number) { + wsCache.set('user.origin', origin) + this.origin = origin + }, setPlatformInfo(info: any | null) { wsCache.set('user.platformInfo', info) this.platformInfo = info @@ -180,6 +202,7 @@ export const UserStore = defineStore('user', { 'exp', 'time', 'weight', + 'origin', 'platformInfo', ] keys.forEach((key) => wsCache.delete('user.' + key)) diff --git a/frontend/src/utils/request.ts b/frontend/src/utils/request.ts index cbf22931..ca366dfc 100644 --- a/frontend/src/utils/request.ts +++ b/frontend/src/utils/request.ts @@ -100,6 +100,9 @@ class HttpService { if (!assistantStore.getType || assistantStore.getType === 2) { config.headers['X-SQLBOT-ASSISTANT-ONLINE'] = assistantStore.getOnline } + if (assistantStore.getHostOrigin) { + config.headers['X-SQLBOT-HOST-ORIGIN'] = assistantStore.getHostOrigin + } } const locale = getLocale() if (locale) { @@ -302,6 +305,9 @@ class HttpService { encodeURIComponent(assistantStore.getCertificate) ) } + if (assistantStore.getHostOrigin) { + heads['X-SQLBOT-HOST-ORIGIN'] = assistantStore.getHostOrigin + } if (!assistantStore.getType || assistantStore.getType === 2) { heads['X-SQLBOT-ASSISTANT-ONLINE'] = assistantStore.getOnline } diff --git a/frontend/src/views/WelcomeView.vue b/frontend/src/views/WelcomeView.vue index fb382c86..e8813fd2 100644 --- a/frontend/src/views/WelcomeView.vue +++ b/frontend/src/views/WelcomeView.vue @@ -18,8 +18,9 @@ const router = useRouter() const userStore = useUserStore() const logout = async () => { - await userStore.logout() - router.push('/login') + if (!(await userStore.logout())) { + router.push('/login') + } } diff --git a/frontend/src/views/chat/QuickQuestion.vue b/frontend/src/views/chat/QuickQuestion.vue new file mode 100644 index 00000000..80f5b447 --- /dev/null +++ b/frontend/src/views/chat/QuickQuestion.vue @@ -0,0 +1,187 @@ + + + + + diff --git a/frontend/src/views/chat/RecentQuestion.vue b/frontend/src/views/chat/RecentQuestion.vue new file mode 100644 index 00000000..6ba7d539 --- /dev/null +++ b/frontend/src/views/chat/RecentQuestion.vue @@ -0,0 +1,102 @@ + + + + + diff --git a/frontend/src/views/chat/RecommendQuestion.vue b/frontend/src/views/chat/RecommendQuestion.vue index 5acb9012..c7ff9385 100644 --- a/frontend/src/views/chat/RecommendQuestion.vue +++ b/frontend/src/views/chat/RecommendQuestion.vue @@ -11,6 +11,7 @@ const props = withDefaults( questions?: string firstChat?: boolean disabled?: boolean + position?: string }>(), { recordId: undefined, @@ -18,6 +19,7 @@ const props = withDefaults( questions: '[]', firstChat: false, disabled: false, + position: 'chat', } ) @@ -153,11 +155,25 @@ defineExpose({ getRecommendQuestions, id: () => props.recordId, stop }) diff --git a/frontend/src/views/chat/component/charts/Table.ts b/frontend/src/views/chat/component/charts/Table.ts index 81b50f74..b1a3b1bb 100644 --- a/frontend/src/views/chat/component/charts/Table.ts +++ b/frontend/src/views/chat/component/charts/Table.ts @@ -1,6 +1,16 @@ import { BaseChart, type ChartAxis, type ChartData } from '@/views/chat/component/BaseChart.ts' -import { TableSheet, type S2Options, type S2DataConfig, type S2MountContainer } from '@antv/s2' +import { + TableSheet, + S2Event, + copyToClipboard, + type S2Options, + type S2DataConfig, + type S2MountContainer, +} from '@antv/s2' import { debounce } from 'lodash-es' +import { i18n } from '@/i18n' + +const { t } = i18n.global export class Table extends BaseChart { table?: TableSheet = undefined @@ -63,6 +73,15 @@ export class Table extends BaseChart { if (this.container) { this.table = new TableSheet(this.container, s2DataConfig, s2Options) + // right click + this.table.on(S2Event.GLOBAL_COPIED, (data) => { + ElMessage.success(t('qa.copied')) + console.debug('copied: ', data) + }) + this.table.getCanvasElement().addEventListener('contextmenu', (event) => { + event.preventDefault() + }) + this.table.on(S2Event.GLOBAL_CONTEXT_MENU, (event) => copyData(event, this.table)) } } @@ -75,3 +94,51 @@ export class Table extends BaseChart { this.resizeObserver?.disconnect() } } + +function copyData(event: any, s2?: TableSheet) { + event.preventDefault() + if (!s2) { + return + } + const cells = s2.interaction.getCells() + + if (cells.length == 0) { + return + } else if (cells.length == 1) { + const c = cells[0] + const cellMeta = s2.facet.getCellMeta(c.rowIndex, c.colIndex) + if (cellMeta) { + copyToClipboard(cellMeta.fieldValue as string).finally(() => { + ElMessage.success(t('qa.copied')) + console.debug('copied:', cellMeta.fieldValue) + }) + } + return + } else { + let currentRowIndex = -1 + let currentRowData: Array = [] + const rowData: Array = [] + for (let i = 0; i < cells.length; i++) { + const c = cells[i] + const cellMeta = s2.facet.getCellMeta(c.rowIndex, c.colIndex) + if (!cellMeta) { + continue + } + if (currentRowIndex == -1) { + currentRowIndex = c.rowIndex + } + if (c.rowIndex !== currentRowIndex) { + rowData.push(currentRowData.join('\t')) + currentRowData = [] + currentRowIndex = c.rowIndex + } + currentRowData.push(cellMeta.fieldValue as string) + } + rowData.push(currentRowData.join('\t')) + const finalValue = rowData.join('\n') + copyToClipboard(finalValue).finally(() => { + ElMessage.success(t('qa.copied')) + console.debug('copied:\n', finalValue) + }) + } +} diff --git a/frontend/src/views/chat/index.vue b/frontend/src/views/chat/index.vue index 52c1bf2e..4133780e 100644 --- a/frontend/src/views/chat/index.vue +++ b/frontend/src/views/chat/index.vue @@ -145,10 +145,10 @@ > - {{ appearanceStore.pc_welcome }} + {{ appearanceStore.pc_welcome ?? t('qa.greeting') }}
- {{ appearanceStore.pc_welcome_desc }} + {{ appearanceStore.pc_welcome_desc ?? t('qa.hint_description') }}
@@ -221,18 +221,18 @@ :msg="message" :hide-avatar="message.first_chat" > - + + + + + + + + + + + +
+
+ +
{ + quickQuestionRef.value.getRecommendQuestions() + }) } } -const recommendQuestionRef = ref() - function getRecommendQuestions(id?: number) { nextTick(() => { if (recommendQuestionRef.value) { @@ -1178,6 +1196,31 @@ onMounted(() => { } } + .quick_question { + width: calc(100% - 2px); + position: absolute; + margin-left: 1px; + margin-top: 1px; + left: 0; + bottom: 0; + padding-bottom: 12px; + padding-left: 12px; + z-index: 10; + background: transparent; + line-height: 22px; + font-size: 14px; + font-weight: 400; + border-top-right-radius: 16px; + border-top-left-radius: 16px; + color: rgba(100, 106, 115, 1); + display: flex; + align-items: center; + + .name { + color: rgba(31, 35, 41, 1); + } + } + .input-area { border-color: #d9dcdf; diff --git a/frontend/src/views/ds/Card.vue b/frontend/src/views/ds/Card.vue index 1a9759e8..a30da180 100644 --- a/frontend/src/views/ds/Card.vue +++ b/frontend/src/views/ds/Card.vue @@ -136,7 +136,7 @@ const onClickOutside = () => { {{ $t('datasource.edit') }} -