# --- Excel import/export endpoints for data training ---------------------
# NOTE(review): reconstructed from a whitespace-mangled patch.  The original
# hunk imported HTTPException from http.client; FastAPI only converts
# fastapi.HTTPException into an HTTP error response, so every raise below
# would have surfaced as a 500.  Corrected import:
from fastapi import HTTPException

path = settings.EXCEL_PATH

from sqlalchemy.orm import sessionmaker, scoped_session
from common.core.db import engine
from sqlmodel import Session

# Worker threads (asyncio.to_thread) need their own session; the request-scoped
# SessionDep is not thread-safe to share.
session_maker = scoped_session(sessionmaker(bind=engine, class_=Session))


@router.post("/uploadExcel")
async def upload_excel(trans: Trans, current_user: CurrentUser, file: UploadFile = File(...)):
    """Import data-training records from an uploaded Excel workbook.

    Every sheet is read; each row (question, description, datasource name and,
    for oid == 1, advanced-application name) becomes a DataTrainingInfo that is
    validated and bulk-inserted by ``batch_create_training``.  Rows that fail
    validation are written to ``<base>_error.xlsx`` under EXCEL_PATH and the
    generated name is returned so the client can fetch it via
    ``/download-fail-info/{filename}``.

    Raises:
        HTTPException(400): when the file extension is not .xlsx/.xls.
    """
    allowed_extensions = {"xlsx", "xls"}
    # splitext keeps multi-dot names ("report.v2.xlsx") intact and requires a
    # real dot-suffix; the original split('.')[0]/[1] mangled such names and
    # endswith() accepted names like "fooxlsx" with no dot at all.
    stem, ext = os.path.splitext(file.filename or "")
    ext = ext.lstrip(".").lower()
    if ext not in allowed_extensions:
        raise HTTPException(400, "Only support .xlsx/.xls")

    os.makedirs(path, exist_ok=True)
    # Random suffix avoids collisions between concurrent uploads of the same
    # workbook name; basename() strips any client-supplied directory parts.
    base_filename = f"{os.path.basename(stem)}_{hashlib.sha256(uuid.uuid4().bytes).hexdigest()[:10]}"
    filename = f"{base_filename}.{ext}"
    save_path = os.path.join(path, filename)
    with open(save_path, "wb") as f:
        f.write(await file.read())

    oid = current_user.oid

    # Columns: question, description, datasource name; org 1 additionally maps
    # column 3 to an advanced-application name.
    use_cols = [0, 1, 2, 3] if oid == 1 else [0, 1, 2]

    def inner():
        session = session_maker()
        try:
            sheet_names = pd.ExcelFile(save_path).sheet_names

            import_data = []
            for sheet_name in sheet_names:
                df = pd.read_excel(
                    save_path,
                    sheet_name=sheet_name,
                    engine='calamine',
                    header=0,
                    usecols=use_cols,
                    dtype=str,
                ).fillna("")

                for _, row in df.iterrows():
                    # Skip fully blank rows.  fillna("") above means
                    # row.isnull().all() (the original test) could never be
                    # True, so blank rows were silently turned into
                    # all-None records; test for emptiness instead.
                    if all(not str(cell).strip() for cell in row):
                        continue

                    # After fillna("") every cell is a str, so notna() checks
                    # are redundant; `or None` normalises blanks to None.
                    question = row[0].strip() or None
                    description = row[1].strip() or None
                    datasource_name = row[2].strip() or None

                    advanced_application_name = None
                    if oid == 1 and len(row) > 3:
                        advanced_application_name = row[3].strip() or None

                    if oid == 1:
                        import_data.append(
                            DataTrainingInfo(oid=oid, question=question, description=description,
                                             datasource_name=datasource_name,
                                             advanced_application_name=advanced_application_name))
                    else:
                        import_data.append(
                            DataTrainingInfo(oid=oid, question=question, description=description,
                                             datasource_name=datasource_name))

            res = batch_create_training(session, import_data, oid, trans)
            failed_records = res['failed_records']
            error_excel_filename = None

            if failed_records:
                data_list = [{
                    "question": obj['data'].question,
                    "description": obj['data'].description,
                    "datasource_name": obj['data'].datasource_name,
                    "advanced_application_name": obj['data'].advanced_application_name,
                    "errors": obj['errors'],
                } for obj in failed_records]

                fields = [
                    AxisObj(name=trans('i18n_data_training.problem_description'), value='question'),
                    AxisObj(name=trans('i18n_data_training.sample_sql'), value='description'),
                    AxisObj(name=trans('i18n_data_training.effective_data_sources'), value='datasource_name'),
                ]
                if oid == 1:
                    fields.append(
                        AxisObj(name=trans('i18n_data_training.advanced_application'),
                                value='advanced_application_name'))
                fields.append(AxisObj(name=trans('i18n_data_training.error_info'), value='errors'))

                md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list)
                err_df = pd.DataFrame(md_data, columns=_fields_list)
                error_excel_filename = f"{base_filename}_error.xlsx"
                err_df.to_excel(os.path.join(path, error_excel_filename), index=False)

            return {
                'success_count': res['success_count'],
                'failed_count': len(failed_records),
                'duplicate_count': res['duplicate_count'],
                'original_count': res['original_count'],
                'error_excel_filename': error_excel_filename,
            }
        finally:
            # scoped_session leaks the thread-local session if never removed.
            session_maker.remove()

    return await asyncio.to_thread(inner)


@router.get("/download-fail-info/{filename}")
async def download_excel(filename: str, trans: Trans):
    """Download a previously generated ``<base>_error.xlsx`` failure report.

    Raises:
        HTTPException(400): when the name is not an ``_error.xlsx`` report.
        HTTPException(404): when the report does not exist under EXCEL_PATH.
    """
    # Security fix: the original joined the raw client value into EXCEL_PATH
    # first, so "../../etc/x_error.xlsx" escaped the directory.  Strip any
    # directory components before building the path.
    filename = os.path.basename(filename)

    if not filename.endswith('_error.xlsx'):
        raise HTTPException(400, "Only support _error.xlsx")

    file_path = os.path.join(path, filename)
    if not os.path.exists(file_path):
        raise HTTPException(404, "File Not Exists")

    return FileResponse(
        path=file_path,
        filename=filename,
        media_type='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
    )
def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo], oid: int, trans: Trans):
    """Bulk-import data-training records.

    Pipeline: in-memory de-duplication -> name-to-id resolution for datasources
    and (oid == 1 only) advanced applications -> per-row validation -> duplicate
    check against the DB -> one bulk INSERT -> embedding refresh.

    Args:
        session: database session (thread-local, caller-owned).
        info_list: candidate records parsed from the uploaded workbook.
        oid: organisation id; oid == 1 enables the advanced-application column.
        trans: i18n translation callable for error messages.

    Returns:
        dict with ``success_count``, ``failed_records`` (each ``{'data': info,
        'errors': [...]}`` keeping the original object), ``duplicate_count``,
        ``original_count`` and ``deduplicated_count``.
    """
    if not info_list:
        return {
            'success_count': 0,
            'failed_records': [],
            'duplicate_count': 0,
            'original_count': 0,
            'deduplicated_count': 0,
        }

    create_time = datetime.datetime.now()
    failed_records = []      # rows rejected by validation / duplication
    success_count = 0
    inserted_ids = []        # primary keys of committed rows

    # --- 1. in-memory de-duplication -------------------------------------
    # Identity = (question, datasource name, advanced-application name),
    # case-insensitive and whitespace-trimmed.
    unique_records = {}
    duplicate_records = []
    for info in info_list:
        unique_key = (
            info.question.strip().lower() if info.question else "",
            info.datasource_name.strip().lower() if info.datasource_name else "",
            info.advanced_application_name.strip().lower() if info.advanced_application_name else "",
        )
        if unique_key in unique_records:
            duplicate_records.append(info)
        else:
            unique_records[unique_key] = info
    deduplicated_list = list(unique_records.values())

    # --- 2. preload name -> id maps (avoids a query per row) -------------
    datasource_name_to_id = {}
    datasource_stmt = select(CoreDatasource.id, CoreDatasource.name).where(CoreDatasource.oid == oid)
    for ds in session.execute(datasource_stmt).all():
        datasource_name_to_id[ds.name.strip()] = ds.id

    assistant_name_to_id = {}
    if oid == 1:
        assistant_stmt = select(AssistantModel.id, AssistantModel.name).where(AssistantModel.type == 1)
        for assistant in session.execute(assistant_stmt).all():
            assistant_name_to_id[assistant.name.strip()] = assistant.id

    # --- 3. validation + DB duplicate check ------------------------------
    valid_records = []
    for info in deduplicated_list:
        error_messages = []

        if not info.question or not info.question.strip():
            error_messages.append(trans("i18n_data_training.question_cannot_be_empty"))
        if not info.description or not info.description.strip():
            error_messages.append(trans("i18n_data_training.description_cannot_be_empty"))

        datasource_id = None
        if info.datasource_name and info.datasource_name.strip():
            if info.datasource_name.strip() in datasource_name_to_id:
                datasource_id = datasource_name_to_id[info.datasource_name.strip()]
            else:
                error_messages.append(trans("i18n_data_training.datasource_not_found").format(info.datasource_name))

        advanced_application_id = None
        if oid == 1 and info.advanced_application_name and info.advanced_application_name.strip():
            if info.advanced_application_name.strip() in assistant_name_to_id:
                advanced_application_id = assistant_name_to_id[info.advanced_application_name.strip()]
            else:
                error_messages.append(
                    trans("i18n_data_training.advanced_application_not_found").format(info.advanced_application_name))

        # Datasource and advanced application may not both be empty.
        if oid == 1:
            if not datasource_id and not advanced_application_id:
                error_messages.append(trans("i18n_data_training.datasource_assistant_cannot_be_none"))
        else:
            if not datasource_id:
                error_messages.append(trans("i18n_data_training.datasource_cannot_be_none"))

        if error_messages:
            # Keep the original info object so the caller can echo it back.
            failed_records.append({'data': info, 'errors': error_messages})
            continue

        # Same question + same target already stored in the DB?
        stmt = select(DataTraining.id).where(
            and_(
                DataTraining.question == info.question.strip(),
                DataTraining.oid == oid,
            )
        )
        if oid == 1:
            if datasource_id is not None and advanced_application_id is not None:
                stmt = stmt.where(
                    or_(
                        DataTraining.datasource == datasource_id,
                        DataTraining.advanced_application == advanced_application_id,
                    )
                )
            elif datasource_id is not None:
                stmt = stmt.where(DataTraining.datasource == datasource_id)
            elif advanced_application_id is not None:
                stmt = stmt.where(DataTraining.advanced_application == advanced_application_id)
        else:
            if datasource_id is not None:
                stmt = stmt.where(DataTraining.datasource == datasource_id)

        if session.query(stmt.exists()).scalar():
            failed_records.append({'data': info, 'errors': [trans("i18n_data_training.exists_in_db")]})
            continue

        valid_records.append({
            'info': info,
            'datasource_id': datasource_id,
            'advanced_application_id': advanced_application_id,
        })

    # --- 4. bulk insert ---------------------------------------------------
    if valid_records:
        data_training_objects = []
        for record in valid_records:
            info = record['info']
            data_training_objects.append(DataTraining(
                question=info.question.strip(),
                description=info.description.strip(),
                oid=oid,
                datasource=record['datasource_id'],
                # Advanced applications exist only for oid == 1.
                advanced_application=record['advanced_application_id'] if oid == 1 else None,
                create_time=create_time,
                enabled=info.enabled if info.enabled is not None else True,
            ))

        try:
            # Fix: the original used session.bulk_save_objects() and then
            # session.refresh(obj).  bulk_save_objects does not attach the
            # objects to the session and does not populate primary keys, so
            # refresh() raised InvalidRequestError and the ids were never
            # available.  add_all() + flush() assigns PKs before commit.
            session.add_all(data_training_objects)
            session.flush()
            inserted_ids = [obj.id for obj in data_training_objects]
            session.commit()
            success_count = len(inserted_ids)

            if inserted_ids:
                run_save_data_training_embeddings(inserted_ids)
        except Exception as e:
            session.rollback()
            # Mark every pending record as failed with the DB error.
            for record in valid_records:
                failed_records.append({'data': record['info'], 'errors': [str(e)]})
            success_count = 0
            inserted_ids = []

    return {
        'success_count': success_count,
        'failed_records': failed_records,
        'duplicate_count': len(duplicate_records),
        'original_count': len(info_list),
        'deduplicated_count': len(deduplicated_list),
    }
"Template name already exists", diff --git a/backend/locales/ko-KR.json b/backend/locales/ko-KR.json index 4e3a4e6c..c9c339af 100644 --- a/backend/locales/ko-KR.json +++ b/backend/locales/ko-KR.json @@ -58,7 +58,12 @@ "problem_description": "문제 설명", "sample_sql": "예시 SQL", "effective_data_sources": "유효 데이터 소스", - "advanced_application": "고급 애플리케이션" + "advanced_application": "고급 애플리케이션", + "error_info": "오류 정보", + "question_cannot_be_empty": "질문은 비울 수 없습니다", + "description_cannot_be_empty": "예시 SQL은 비울 수 없습니다", + "datasource_not_found": "데이터 소스를 찾을 수 없음", + "advanced_application_not_found": "고급 애플리케이션을 찾을 수 없음" }, "i18n_custom_prompt": { "exists_in_db": "템플릿 이름이 이미 존재합니다", diff --git a/backend/locales/zh-CN.json b/backend/locales/zh-CN.json index 113c9029..577ef8c2 100644 --- a/backend/locales/zh-CN.json +++ b/backend/locales/zh-CN.json @@ -58,7 +58,12 @@ "problem_description": "问题描述", "sample_sql": "示例 SQL", "effective_data_sources": "生效数据源", - "advanced_application": "高级应用" + "advanced_application": "高级应用", + "error_info": "错误信息", + "question_cannot_be_empty": "问题不能为空", + "description_cannot_be_empty": "示例 SQL 不能为空", + "datasource_not_found": "找不到数据源", + "advanced_application_not_found": "找不到高级应用" }, "i18n_custom_prompt": { "exists_in_db": "模版名称已存在", diff --git a/frontend/src/api/training.ts b/frontend/src/api/training.ts index ebc68e69..9b247ae2 100644 --- a/frontend/src/api/training.ts +++ b/frontend/src/api/training.ts @@ -15,4 +15,9 @@ export const trainingApi = { responseType: 'blob', requestOptions: { customError: true }, }), + downloadError: (path: any) => + request.get(`/system/data-training/download-fail-info/${path}`, { + responseType: 'blob', + requestOptions: { customError: true }, + }), } diff --git a/frontend/src/i18n/en.json b/frontend/src/i18n/en.json index 2a582627..23fefa9b 100644 --- a/frontend/src/i18n/en.json +++ b/frontend/src/i18n/en.json @@ -44,7 +44,9 @@ "sql_statement": "SQL Statement", "edit_training_data": "Edit SQL Sample", 
"all_236_terms": "Export all {msg} sample SQL records?", - "sales_this_year": "Do you want to delete the SQL Sample: {msg}?" + "sales_this_year": "Do you want to delete the SQL Sample: {msg}?", + "upload_success": "Import Successful", + "upload_failed": "Import successful: {success} records. Failed: {fail} records. For details, see: {fail_info}" }, "professional": { "cannot_be_repeated": "Term name, synonyms cannot be repeated", diff --git a/frontend/src/i18n/ko-KR.json b/frontend/src/i18n/ko-KR.json index 5093166f..a3d3269f 100644 --- a/frontend/src/i18n/ko-KR.json +++ b/frontend/src/i18n/ko-KR.json @@ -44,7 +44,9 @@ "sql_statement": "SQL 문", "edit_training_data": "예제 SQL 편집", "all_236_terms": "모든 {msg}개의 예시 SQL 기록을 내보내시겠습니까?", - "sales_this_year": "예제 SQL을 삭제하시겠습니까: {msg}?" + "sales_this_year": "예제 SQL을 삭제하시겠습니까: {msg}?", + "upload_success": "가져오기 성공", + "upload_failed": "성공: {success}건, 실패: {fail}건. 자세한 내용은 다음을 참조하세요: {fail_info}" }, "professional": { "cannot_be_repeated": "용어 이름과 동의어는 중복될 수 없습니다", diff --git a/frontend/src/i18n/zh-CN.json b/frontend/src/i18n/zh-CN.json index 5f2147a6..95f1ce5b 100644 --- a/frontend/src/i18n/zh-CN.json +++ b/frontend/src/i18n/zh-CN.json @@ -44,7 +44,9 @@ "sql_statement": "SQL 语句", "edit_training_data": "编辑示例 SQL", "all_236_terms": "是否导出全部 {msg} 条示例 SQL?", - "sales_this_year": "是否删除示例 SQL:{msg}?" 
+ "sales_this_year": "是否删除示例 SQL:{msg}?", + "upload_success": "导入成功", + "upload_failed": "导入成功 {success} 条,失败 {fail} 条,失败信息详见:{fail_info}" }, "professional": { "cannot_be_repeated": "术语名称,同义词不能重复", diff --git a/frontend/src/views/system/training/index.vue b/frontend/src/views/system/training/index.vue index f4775c98..b45bdfb9 100644 --- a/frontend/src/views/system/training/index.vue +++ b/frontend/src/views/system/training/index.vue @@ -16,6 +16,10 @@ import { useUserStore } from '@/stores/user' import { useI18n } from 'vue-i18n' import { cloneDeep } from 'lodash-es' import { getAdvancedApplicationList } from '@/api/embedded.ts' +import { genFileId } from 'element-plus' + +import type { UploadInstance, UploadProps, UploadRawFile } from 'element-plus' +import { useCache } from '@/utils/useCache.ts' interface Form { id?: string | null @@ -33,7 +37,7 @@ const keywords = ref('') const oldKeywords = ref('') const searchLoading = ref(false) const { copy } = useClipboard({ legacy: true }) - +const { wsCache } = useCache() const options = ref([]) const adv_options = ref([]) const selectable = () => { @@ -86,6 +90,94 @@ const cancelDelete = () => { isIndeterminate.value = false } +const upload = () => {} + +const uploadRef = ref() +const uploadLoading = ref(false) + +const token = wsCache.get('user.token') +const headers = ref({ 'X-SQLBOT-TOKEN': `Bearer ${token}` }) +const getUploadURL = import.meta.env.VITE_API_BASE_URL + '/system/data-training/uploadExcel' + +const handleExceed: UploadProps['onExceed'] = (files) => { + uploadRef.value!.clearFiles() + const file = files[0] as UploadRawFile + file.uid = genFileId() + uploadRef.value!.handleStart(file) +} + +const beforeUpload = (rawFile: any) => { + if (rawFile.size / 1024 / 1024 > 50) { + ElMessage.error(t('common.not_exceed_50mb')) + return false + } + uploadLoading.value = true + return true +} +const onSuccess = (response: any) => { + uploadRef.value!.clearFiles() + search() + + if (response?.data?.failed_count > 0 && 
response?.data?.error_excel_filename) { + ElMessage.error( + t('training.upload_failed', { + success: response.data.success_count, + fail: response.data.failed_count, + fail_info: response.data.error_excel_filename, + }) + ) + trainingApi + .downloadError(response.data.error_excel_filename) + .then((res) => { + const blob = new Blob([res], { + type: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + }) + const link = document.createElement('a') + link.href = URL.createObjectURL(blob) + link.download = `${t('training.data_training')}_error.xlsx` + document.body.appendChild(link) + link.click() + document.body.removeChild(link) + }) + .catch(async (error) => { + if (error.response) { + try { + let text = await error.response.data.text() + try { + text = JSON.parse(text) + } finally { + ElMessage({ + message: text, + type: 'error', + showClose: true, + }) + } + } catch (e) { + console.error('Error processing error response:', e) + } + } else { + console.error('Other error:', error) + ElMessage({ + message: error, + type: 'error', + showClose: true, + }) + } + }) + .finally(() => { + uploadLoading.value = false + }) + } else { + ElMessage.success(t('training.upload_success')) + uploadLoading.value = false + } +} + +const onError = () => { + uploadLoading.value = false + uploadRef.value!.clearFiles() +} + const exportExcel = () => { ElMessageBox.confirm(t('training.all_236_terms', { msg: pageInfo.total }), { confirmButtonType: 'primary', @@ -360,13 +452,13 @@ const onRowFormClose = () => { {{ $t('professional.export_all') }} - - - {{ $t('user.batch_import') }} - + + + + {{ $t('user.batch_import') }} + +