From 4d53a02e0ba53ec755002d617fb38445cb36f4a6 Mon Sep 17 00:00:00 2001 From: ulleo Date: Wed, 19 Nov 2025 15:58:56 +0800 Subject: [PATCH] feat: import Sample SQL / Terminologies --- .../apps/data_training/curd/data_training.py | 28 ++-- backend/apps/terminology/curd/terminology.py | 74 +++++----- backend/locales/zh-CN.json | 7 +- backend/pyproject.toml | 2 +- frontend/src/views/system/prompt/index.vue | 127 +++++++++++++++++- 5 files changed, 191 insertions(+), 47 deletions(-) diff --git a/backend/apps/data_training/curd/data_training.py b/backend/apps/data_training/curd/data_training.py index bc8ddc22..20ea3fb6 100644 --- a/backend/apps/data_training/curd/data_training.py +++ b/backend/apps/data_training/curd/data_training.py @@ -145,9 +145,11 @@ def get_all_data_training(session: SessionDep, name: Optional[str] = None, oid: return _list -def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans): +def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans, skip_embedding: bool = False): """ 创建单个数据训练记录 + Args: + skip_embedding: 是否跳过embedding处理(用于批量插入) """ # 基本验证 if not info.question or not info.question.strip(): @@ -203,8 +205,9 @@ def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans session.refresh(data_training) session.commit() - # 处理embedding - run_save_data_training_embeddings([data_training.id]) + # 处理embedding(批量插入时跳过) + if not skip_embedding: + run_save_data_training_embeddings([data_training.id]) return data_training.id @@ -247,11 +250,11 @@ def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans raise Exception(trans("i18n_data_training.exists_in_db")) stmt = update(DataTraining).where(and_(DataTraining.id == info.id)).values( - question=info.question, - description=info.description, + question=info.question.strip(), + description=info.description.strip(), datasource=info.datasource, - enabled=info.enabled, 
advanced_application=info.advanced_application, + enabled=info.enabled if info.enabled is not None else True ) session.execute(stmt) session.commit() @@ -374,8 +377,8 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo] if valid_records: for info in valid_records: try: - # 直接复用create_training方法 - training_id = create_training(session, info, oid, trans) + # 直接复用create_training方法,跳过embedding处理 + training_id = create_training(session, info, oid, trans, skip_embedding=True) inserted_ids.append(training_id) success_count += 1 @@ -387,6 +390,15 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo] 'errors': [str(e)] }) + # 批量处理embedding(只在最后执行一次) + if success_count > 0 and inserted_ids: + try: + run_save_data_training_embeddings(inserted_ids) + except Exception as e: + # 如果embedding处理失败,记录错误但不回滚数据 + print(f"Embedding processing failed: {str(e)}") + # 可以选择将embedding失败的信息记录到日志或返回给调用方 + return { 'success_count': success_count, 'failed_records': failed_records, diff --git a/backend/apps/terminology/curd/terminology.py b/backend/apps/terminology/curd/terminology.py index cb3dd230..282041fc 100644 --- a/backend/apps/terminology/curd/terminology.py +++ b/backend/apps/terminology/curd/terminology.py @@ -200,9 +200,12 @@ def get_all_terminology(session: SessionDep, name: Optional[str] = None, oid: Op return _list -def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans): +def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans, + skip_embedding: bool = False): """ 创建单个术语记录 + Args: + skip_embedding: 是否跳过embedding处理(用于批量插入) """ # 基本验证 if not info.word or not info.word.strip(): @@ -221,16 +224,16 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra raise Exception(trans("i18n_terminology.datasource_cannot_be_none")) parent = Terminology( - word=info.word, + word=info.word.strip(), create_time=create_time, - 
description=info.description, + description=info.description.strip(), oid=oid, specific_ds=specific_ds, enabled=info.enabled, datasource_ids=datasource_ids ) - words = [info.word] + words = [info.word.strip()] for child_word in info.other_words: # 先检查是否为空字符串 if not child_word or child_word.strip() == "": @@ -239,7 +242,7 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra if child_word in words: raise Exception(trans("i18n_terminology.cannot_be_repeated")) else: - words.append(child_word) + words.append(child_word.strip()) # 基础查询条件(word 和 oid 必须满足) base_query = and_( @@ -288,7 +291,7 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra child_list.append( Terminology( pid=parent.id, - word=other_word, + word=other_word.strip(), create_time=create_time, oid=oid, enabled=info.enabled, @@ -303,8 +306,9 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra session.commit() - # 处理embedding - run_save_terminology_embeddings([parent.id]) + # 处理embedding(批量插入时跳过) + if not skip_embedding: + run_save_terminology_embeddings([parent.id]) return parent.id @@ -380,19 +384,9 @@ def batch_create_terminology(session: SessionDep, info_list: List[TerminologyInf # 基本验证 if not info.word or not info.word.strip(): error_messages.append(trans("i18n_terminology.word_cannot_be_empty")) - failed_records.append({ - 'data': info, - 'errors': error_messages - }) - continue if not info.description or not info.description.strip(): error_messages.append(trans("i18n_terminology.description_cannot_be_empty")) - failed_records.append({ - 'data': info, - 'errors': error_messages - }) - continue # 根据specific_ds决定是否验证数据源 specific_ds = info.specific_ds if info.specific_ds is not None else False @@ -455,8 +449,8 @@ def batch_create_terminology(session: SessionDep, info_list: List[TerminologyInf if valid_records: for info in valid_records: try: - # 直接复用create_terminology方法 - terminology_id = 
create_terminology(session, info, oid, trans) + # 直接复用create_terminology方法,跳过embedding处理 + terminology_id = create_terminology(session, info, oid, trans, skip_embedding=True) inserted_ids.append(terminology_id) success_count += 1 @@ -468,6 +462,15 @@ def batch_create_terminology(session: SessionDep, info_list: List[TerminologyInf 'errors': [str(e)] }) + # 批量处理embedding(只在最后执行一次) + if success_count > 0 and inserted_ids: + try: + run_save_terminology_embeddings(inserted_ids) + except Exception as e: + # 如果embedding处理失败,记录错误但不回滚数据 + print(f"Terminology embedding processing failed: {str(e)}") + # 可以选择将embedding失败的信息记录到日志或返回给调用方 + return { 'success_count': success_count, 'failed_records': failed_records, @@ -492,12 +495,12 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra if not datasource_ids: raise Exception(trans("i18n_terminology.datasource_cannot_be_none")) - words = [info.word] + words = [info.word.strip()] for child in info.other_words: if child in words: raise Exception(trans("i18n_terminology.cannot_be_repeated")) else: - words.append(child) + words.append(child.strip()) # 基础查询条件(word 和 oid 必须满足) base_query = and_( @@ -539,8 +542,8 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra raise Exception(trans("i18n_terminology.exists_in_db")) stmt = update(Terminology).where(and_(Terminology.id == info.id)).values( - word=info.word, - description=info.description, + word=info.word.strip(), + description=info.description.strip(), specific_ds=specific_ds, datasource_ids=datasource_ids, enabled=info.enabled, @@ -553,16 +556,27 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra session.commit() create_time = datetime.datetime.now() - _list: List[Terminology] = [] + # 插入子记录(其他词) + child_list = [] if info.other_words: for other_word in info.other_words: if other_word.strip() == "": continue - _list.append( - Terminology(pid=info.id, word=other_word, create_time=create_time, 
oid=oid,
-                                specific_ds=specific_ds, datasource_ids=datasource_ids, enabled=info.enabled))
-        session.bulk_save_objects(_list)
-        session.flush()
+            child_list.append(
+                Terminology(
+                    pid=info.id,
+                    word=other_word.strip(),
+                    create_time=create_time,
+                    oid=oid,
+                    enabled=info.enabled,
+                    specific_ds=specific_ds,
+                    datasource_ids=datasource_ids
+                )
+            )
+
+    if child_list:
+        session.bulk_save_objects(child_list)
+        session.flush()
     session.commit()
 
     # embedding
diff --git a/backend/locales/zh-CN.json b/backend/locales/zh-CN.json
index 2e136e20..ad0bfffe 100644
--- a/backend/locales/zh-CN.json
+++ b/backend/locales/zh-CN.json
@@ -74,7 +74,12 @@
         "prompt_word_name": "提示词名称",
         "prompt_word_content": "提示词内容",
         "effective_data_sources": "生效数据源",
-        "all_data_sources": "所有数据源"
+        "all_data_sources": "所有数据源",
+        "name_cannot_be_empty": "名称不能为空",
+        "prompt_cannot_be_empty": "提示词内容不能为空",
+        "type_cannot_be_empty": "类型不能为空",
+        "datasource_not_found": "找不到数据源",
+        "datasource_cannot_be_none": "数据源不能为空"
     },
     "i18n_excel_export": {
         "data_is_empty": "表单数据为空,无法导出数据"
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 4c77ea01..913b26df 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -39,7 +39,7 @@ dependencies = [
     "pyyaml (>=6.0.2,<7.0.0)",
     "fastapi-mcp (>=0.3.4,<0.4.0)",
     "tabulate>=0.9.0",
-    "sqlbot-xpack>=0.0.3.45,<1.0.0",
+    "sqlbot-xpack>=0.0.3.46,<1.0.0",
     "fastapi-cache2>=0.2.2",
     "sqlparse>=0.5.3",
     "redis>=6.2.0",
diff --git a/frontend/src/views/system/prompt/index.vue b/frontend/src/views/system/prompt/index.vue
index 1697f2dd..91379b60 100644
--- a/frontend/src/views/system/prompt/index.vue
+++ b/frontend/src/views/system/prompt/index.vue
@@ -14,6 +14,9 @@ import EmptyBackground from '@/views/dashboard/common/EmptyBackground.vue'
 import { useClipboard } from '@vueuse/core'
 import { useI18n } from 'vue-i18n'
 import { cloneDeep } from 'lodash-es'
+import { genFileId, type UploadInstance, type UploadProps, type UploadRawFile } from 'element-plus'
+import { 
settingsApi } from '@/api/setting.ts' +import { useCache } from '@/utils/useCache.ts' interface Form { id?: string | null @@ -80,6 +83,95 @@ const cancelDelete = () => { isIndeterminate.value = false } +const uploadRef = ref() +const uploadLoading = ref(false) + +const { wsCache } = useCache() +const token = wsCache.get('user.token') +const headers = ref({ 'X-SQLBOT-TOKEN': `Bearer ${token}` }) +const getUploadURL = (type: string) => { + return import.meta.env.VITE_API_BASE_URL + `/system/custom_prompt/${type}/uploadExcel` +} + +const handleExceed: UploadProps['onExceed'] = (files) => { + uploadRef.value!.clearFiles() + const file = files[0] as UploadRawFile + file.uid = genFileId() + uploadRef.value!.handleStart(file) +} + +const beforeUpload = (rawFile: any) => { + if (rawFile.size / 1024 / 1024 > 50) { + ElMessage.error(t('common.not_exceed_50mb')) + return false + } + uploadLoading.value = true + return true +} +const onSuccess = (response: any) => { + uploadRef.value!.clearFiles() + search() + + if (response?.data?.failed_count > 0 && response?.data?.error_excel_filename) { + ElMessage.error( + t('training.upload_failed', { + success: response.data.success_count, + fail: response.data.failed_count, + fail_info: response.data.error_excel_filename, + }) + ) + settingsApi + .downloadError(response.data.error_excel_filename) + .then((res) => { + const blob = new Blob([res], { + type: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + }) + const link = document.createElement('a') + link.href = URL.createObjectURL(blob) + link.download = response.data.error_excel_filename + document.body.appendChild(link) + link.click() + document.body.removeChild(link) + }) + .catch(async (error) => { + if (error.response) { + try { + let text = await error.response.data.text() + try { + text = JSON.parse(text) + } finally { + ElMessage({ + message: text, + type: 'error', + showClose: true, + }) + } + } catch (e) { + console.error('Error processing error 
response:', e) + } + } else { + console.error('Other error:', error) + ElMessage({ + message: error, + type: 'error', + showClose: true, + }) + } + }) + .finally(() => { + uploadLoading.value = false + }) + } else { + ElMessage.success(t('training.upload_success')) + uploadLoading.value = false + } +} + +const onError = () => { + uploadLoading.value = false + uploadRef.value!.clearFiles() +} + const exportExcel = () => { let title = '' if (currentType.value === 'GENERATE_SQL') { @@ -383,7 +475,7 @@ const typeChange = (val: any) => { {{ $t('prompt.data_prediction') }} -
+
{ {{ $t('professional.export_all') }} - - - {{ $t('user.batch_import') }} - + + + + {{ $t('user.batch_import') }} + +