From d7462daf1d32f14b07e0c740e3088bcb80270360 Mon Sep 17 00:00:00 2001 From: ulleo Date: Tue, 18 Nov 2025 18:40:03 +0800 Subject: [PATCH] feat: import Sample SQL --- .../apps/data_training/curd/data_training.py | 178 ++++++------------ 1 file changed, 62 insertions(+), 116 deletions(-) diff --git a/backend/apps/data_training/curd/data_training.py b/backend/apps/data_training/curd/data_training.py index ede18941..bc8ddc22 100644 --- a/backend/apps/data_training/curd/data_training.py +++ b/backend/apps/data_training/curd/data_training.py @@ -146,6 +146,9 @@ def get_all_data_training(session: SessionDep, name: Optional[str] = None, oid: def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans): + """ + 创建单个数据训练记录 + """ # 基本验证 if not info.question or not info.question.strip(): raise Exception(trans("i18n_data_training.question_cannot_be_empty")) @@ -154,45 +157,56 @@ def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans raise Exception(trans("i18n_data_training.description_cannot_be_empty")) create_time = datetime.datetime.now() + + # 检查数据源和高级应用不能同时为空 if info.datasource is None and info.advanced_application is None: if oid == 1: raise Exception(trans("i18n_data_training.datasource_assistant_cannot_be_none")) else: raise Exception(trans("i18n_data_training.datasource_cannot_be_none")) - parent = DataTraining(question=info.question, create_time=create_time, description=info.description, oid=oid, - datasource=info.datasource, enabled=info.enabled, - advanced_application=info.advanced_application) - - stmt = select(DataTraining.id).where(and_(DataTraining.question == info.question, DataTraining.oid == oid)) + # 检查重复记录 + stmt = select(DataTraining.id).where( + and_(DataTraining.question == info.question.strip(), DataTraining.oid == oid) + ) if info.datasource is not None and info.advanced_application is not None: stmt = stmt.where( - or_(DataTraining.datasource == info.datasource, - DataTraining.advanced_application == info.advanced_application)) + or_( + DataTraining.datasource == info.datasource, + DataTraining.advanced_application == info.advanced_application + ) + ) elif info.datasource is not None and info.advanced_application is None: - stmt = stmt.where(and_(DataTraining.datasource == info.datasource)) + stmt = stmt.where(DataTraining.datasource == info.datasource) elif info.datasource is None and info.advanced_application is not None: - stmt = stmt.where(and_(DataTraining.advanced_application == info.advanced_application)) + stmt = stmt.where(DataTraining.advanced_application == info.advanced_application) exists = session.query(stmt.exists()).scalar() if exists: raise Exception(trans("i18n_data_training.exists_in_db")) - result = DataTraining(**parent.model_dump()) + # 创建记录 + data_training = DataTraining( + question=info.question.strip(), + description=info.description.strip(), + oid=oid, + datasource=info.datasource, + advanced_application=info.advanced_application, + create_time=create_time, + enabled=info.enabled if info.enabled is not None else True + ) - session.add(parent) + session.add(data_training) session.flush() - session.refresh(parent) - - result.id = parent.id + session.refresh(data_training) session.commit() - # embedding - run_save_data_training_embeddings([result.id]) + # 处理embedding + run_save_data_training_embeddings([data_training.id]) - return result.id + return data_training.id def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans): @@ -250,14 +264,7 @@ def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo], oid: int, trans: Trans): """ - 批量创建数据训练记录 - Args: - session: 数据库会话 - info_list: DataTrainingInfo对象列表 - oid: 组织ID - trans: 翻译对象 - Returns: - dict: 包含成功数量、失败记录和统计信息的结果字典 + 批量创建数据训练记录(复用单条插入逻辑) """ if not info_list: return { @@ -268,17 +275,16 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo] 'deduplicated_count': 0 } - create_time = datetime.datetime.now() - failed_records = [] # 存储失败的数据和原因 + failed_records = [] success_count = 0 - inserted_ids = [] # 存储成功插入的ID + inserted_ids = [] # 第一步:数据去重 unique_records = {} - duplicate_records = [] # 存储重复的数据 + duplicate_records = [] for info in info_list: - # 创建唯一标识:问题 + 数据源名称 + 高级应用名称 + # 创建唯一标识 unique_key = ( info.question.strip().lower() if info.question else "", info.datasource_name.strip().lower() if info.datasource_name else "", @@ -286,7 +292,6 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo] ) if unique_key in unique_records: - # 如果是重复数据,记录到重复列表中 duplicate_records.append(info) else: unique_records[unique_key] = info @@ -294,14 +299,13 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo] # 将去重后的数据转换为列表 deduplicated_list = list(unique_records.values()) - # 预加载数据源名称到ID的映射(CoreDatasource需要判断oid) + # 预加载数据源和高级应用名称到ID的映射 datasource_name_to_id = {} datasource_stmt = select(CoreDatasource.id, CoreDatasource.name).where(CoreDatasource.oid == oid) datasource_result = session.execute(datasource_stmt).all() for ds in datasource_result: datasource_name_to_id[ds.name.strip()] = ds.id - # 只有在oid=1时才预加载高级应用名称到ID的映射 assistant_name_to_id = {} if oid == 1: assistant_stmt = select(AssistantModel.id, AssistantModel.name).where(AssistantModel.type == 1) @@ -309,7 +313,7 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo] for assistant in assistant_result: assistant_name_to_id[assistant.name.strip()] = assistant.id - # 验证和准备数据 + # 验证和转换数据 valid_records = [] for info in deduplicated_list: error_messages = [] @@ -321,7 +325,7 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo] if not info.description or not info.description.strip(): error_messages.append(trans("i18n_data_training.description_cannot_be_empty")) - # 数据源验证 + # 数据源验证和转换 datasource_id = None if info.datasource_name and info.datasource_name.strip(): if info.datasource_name.strip() in datasource_name_to_id: @@ -329,7 +333,7 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo] else: error_messages.append(trans("i18n_data_training.datasource_not_found").format(info.datasource_name)) - # 高级应用验证(只有在oid=1时才需要) + # 高级应用验证和转换 advanced_application_id = None if oid == 1 and info.advanced_application_name and info.advanced_application_name.strip(): if info.advanced_application_name.strip() in assistant_name_to_id: @@ -346,101 +350,43 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo] if not datasource_id: error_messages.append(trans("i18n_data_training.datasource_cannot_be_none")) - # 如果有错误,添加到失败列表 if error_messages: - # 返回原始的info对象,不包含转换后的ID failed_records.append({ - 'data': info, # 直接返回原始传入的数据 + 'data': info, 'errors': error_messages }) continue - # 检查数据库中是否已存在重复记录 - stmt = select(DataTraining.id).where( - and_( - DataTraining.question == info.question.strip(), - DataTraining.oid == oid - ) + # 创建处理后的DataTrainingInfo对象 + processed_info = DataTrainingInfo( + question=info.question.strip(), + description=info.description.strip(), + datasource=datasource_id, + datasource_name=info.datasource_name, + advanced_application=advanced_application_id, + advanced_application_name=info.advanced_application_name, + enabled=info.enabled if info.enabled is not None else True ) - # 根据oid决定重复检查条件 - if oid == 1: - if datasource_id is not None and advanced_application_id is not None: - stmt = stmt.where( - or_( - DataTraining.datasource == datasource_id, - DataTraining.advanced_application == advanced_application_id - ) - ) - elif datasource_id is not None: - stmt = stmt.where(DataTraining.datasource == datasource_id) - elif advanced_application_id is not None: - stmt = stmt.where(DataTraining.advanced_application == advanced_application_id) - else: - # oid != 1时,只检查数据源 - if datasource_id is not None: - stmt = stmt.where(DataTraining.datasource == datasource_id) - - exists = session.query(stmt.exists()).scalar() - - if exists: - # 返回原始的info对象 - failed_records.append({ - 'data': info, # 直接返回原始传入的数据 - 'errors': [trans("i18n_data_training.exists_in_db")] - }) - continue - - # 验证通过,添加到有效记录 - valid_records.append({ - 'info': info, - 'datasource_id': datasource_id, - 'advanced_application_id': advanced_application_id - }) + valid_records.append(processed_info) - # 批量插入有效记录 + # 使用事务处理有效记录 if valid_records: - data_training_objects = [] - for record in valid_records: - info = record['info'] - data_training = DataTraining( - question=info.question.strip(), - description=info.description.strip(), - oid=oid, - datasource=record['datasource_id'], - advanced_application=record['advanced_application_id'] if oid == 1 else None, # 只有oid=1才设置高级应用 - create_time=create_time, - enabled=info.enabled if info.enabled is not None else True - ) - data_training_objects.append(data_training) - - try: - # 批量插入 - session.bulk_save_objects(data_training_objects, return_defaults=True) - session.commit() + for info in valid_records: + try: + # 直接复用create_training方法 + training_id = create_training(session, info, oid, trans) + inserted_ids.append(training_id) + success_count += 1 - # 获取插入的ID - for obj in data_training_objects: - if obj.id is not None: # 确保ID已经被赋值 - inserted_ids.append(obj.id) - success_count += 1 - - except Exception as e: - session.rollback() - # 将所有的有效记录标记为失败 - for record in valid_records: - # 返回原始的info对象 + except Exception as e: + # 如果单条插入失败,回滚当前记录 + session.rollback() failed_records.append({ - 'data': record['info'], # 直接返回原始传入的数据 + 'data': info, 'errors': [str(e)] }) - success_count = 0 - - # 批量处理embedding - if success_count > 0 and inserted_ids: - run_save_data_training_embeddings(inserted_ids) - # 返回结果,包含去重统计信息 return { 'success_count': success_count, 'failed_records': failed_records,