diff --git a/backend/apps/db/constant.py b/backend/apps/db/constant.py index ddd456c8..3b6a66c1 100644 --- a/backend/apps/db/constant.py +++ b/backend/apps/db/constant.py @@ -3,6 +3,8 @@ from enum import Enum +from common.utils.utils import equals_ignore_case + class ConnectType(Enum): sqlalchemy = ('sqlalchemy') @@ -37,7 +39,8 @@ def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType, tem @classmethod def get_db(cls, type, default_if_none=False): for db in cls: - if db.type == type: + """ if db.type == type: """ + if equals_ignore_case(db.type, type): return db if default_if_none: return DB.pg diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py index 8aaebeaa..e307ac18 100644 --- a/backend/apps/db/db.py +++ b/backend/apps/db/db.py @@ -24,35 +24,35 @@ from apps.system.crud.assistant import get_ds_engine from apps.system.schemas.system_schema import AssistantOutDsSchema from common.core.deps import Trans -from common.utils.utils import SQLBotLogUtil +from common.utils.utils import SQLBotLogUtil, equals_ignore_case from fastapi import HTTPException from apps.db.es_engine import get_es_connect, get_es_index, get_es_fields, get_es_data_by_http def get_uri(ds: CoreDatasource) -> str: - conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if not equals_ignore_case(ds.type, "excel") else get_engine_config() return get_uri_from_config(ds.type, conf) def get_uri_from_config(type: str, conf: DatasourceConf) -> str: db_url: str - if type == "mysql": + if equals_ignore_case(type, "mysql"): if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" else: db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" - elif type == "sqlServer": + elif equals_ignore_case(type, "sqlServer"): if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"mssql+pymssql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" else: db_url = f"mssql+pymssql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" - elif type == "pg" or type == "excel": + elif equals_ignore_case(type, "pg", "excel"): if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"postgresql+psycopg2://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" else: db_url = f"postgresql+psycopg2://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" - elif type == "oracle": - if conf.mode == "service_name": + elif equals_ignore_case(type, "oracle"): + if equals_ignore_case(conf.mode, "service_name"): if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}?service_name={conf.database}&{conf.extraJdbc}" else: @@ -62,7 +62,7 @@ def get_uri_from_config(type: str, conf: DatasourceConf) -> str: db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" else: db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" - elif type == "ck": + elif equals_ignore_case(type, "ck"): if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"clickhouse+http://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" else: @@ -87,7 +87,7 @@ def get_extra_config(conf: DatasourceConf): def get_origin_connect(type: str, conf: DatasourceConf): extra_config_dict = get_extra_config(conf) - if type == "sqlServer": + if equals_ignore_case(type, "sqlServer"): return pymssql.connect( server=conf.host, port=str(conf.port), @@ -102,12 +102,12 @@ def get_origin_connect(type: str, conf: DatasourceConf): # use sqlalchemy def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine: - conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if not equals_ignore_case(ds.type, "excel") else get_engine_config() if conf.timeout is None: conf.timeout = timeout if timeout > 0: conf.timeout = timeout - if ds.type == "pg": + if equals_ignore_case(ds.type, "pg"): if conf.dbSchema is not None and conf.dbSchema != "": engine = create_engine(get_uri(ds), connect_args={"options": f"-c search_path={urllib.parse.quote(conf.dbSchema)}", @@ -117,10 +117,10 @@ def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine: engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout}, pool_timeout=conf.timeout) - elif ds.type == 'sqlServer': + elif equals_ignore_case(ds.type, 'sqlServer'): engine = create_engine('mssql+pymssql://', creator=lambda: get_origin_connect(ds.type, conf), pool_timeout=conf.timeout) - elif ds.type == 'oracle': + elif equals_ignore_case(ds.type, 'oracle'): engine = create_engine(get_uri(ds), pool_timeout=conf.timeout) else: # mysql, ck @@ -152,7 +152,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs else: conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) extra_config_dict = get_extra_config(conf) - if ds.type == 'dm': + if equals_ignore_case(ds.type, 'dm'): with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: try: @@ -164,7 +164,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs if is_raise: raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False - elif ds.type == 'doris' or ds.type == "starrocks": + elif equals_ignore_case(ds.type, 'doris', 'starrocks'): with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=10, read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: @@ -177,7 +177,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs if is_raise: raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False - elif ds.type == 'redshift': + elif equals_ignore_case(ds.type, 'redshift'): with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, @@ -191,7 +191,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs if is_raise: raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False - elif ds.type == 'kingbase': + elif equals_ignore_case(ds.type, 'kingbase'): with psycopg2.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, @@ -205,7 +205,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs if is_raise: raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False - elif ds.type == 'es': + elif equals_ignore_case(ds.type, 'es'): es_conn = get_es_connect(conf) if es_conn.ping(): SQLBotLogUtil.info("success") @@ -233,7 +233,7 @@ def get_version(ds: CoreDatasource | AssistantOutDsSchema): conf = None if isinstance(ds, CoreDatasource): conf = DatasourceConf( - **json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() + **json.loads(aes_decrypt(ds.configuration))) if not equals_ignore_case(ds.type, "excel") else get_engine_config() if isinstance(ds, AssistantOutDsSchema): conf = DatasourceConf() conf.host = ds.host @@ -253,20 +253,20 @@ def get_version(ds: CoreDatasource | AssistantOutDsSchema): version = res[0][0] else: extra_config_dict = get_extra_config(conf) - if ds.type == 'dm': + if equals_ignore_case(ds.type, 'dm'): with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port) as conn, conn.cursor() as cursor: cursor.execute(sql, timeout=10, **extra_config_dict) res = cursor.fetchall() version = res[0][0] - elif ds.type == 'doris' or ds.type == "starrocks": + elif equals_ignore_case(ds.type, 'doris', 'starrocks'): with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=10, read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: cursor.execute(sql) res = cursor.fetchall() version = res[0][0] - elif ds.type == 'redshift' or ds.type == 'es': + elif equals_ignore_case(ds.type, 'redshift', 'es'): version = '' except Exception as e: print(e) @@ -280,11 +280,11 @@ def get_schema(ds: CoreDatasource): if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: sql: str = '' - if ds.type == "sqlServer": + if equals_ignore_case(ds.type, "sqlServer"): sql = """select name from sys.schemas""" - elif ds.type == "pg" or ds.type == "excel": + elif equals_ignore_case(ds.type, "pg", "excel"): sql = """SELECT nspname FROM pg_namespace""" - elif ds.type == "oracle": + elif equals_ignore_case(ds.type, "oracle"): sql = """select * from all_users""" with session.execute(text(sql)) as result: res = result.fetchall() @@ -292,14 +292,14 @@ def get_schema(ds: CoreDatasource): return res_list else: extra_config_dict = get_extra_config(conf) - if ds.type == 'dm': + if equals_ignore_case(ds.type, 'dm'): with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: cursor.execute("""select OBJECT_NAME from dba_objects where object_type='SCH'""", timeout=conf.timeout) res = cursor.fetchall() res_list = [item[0] for item in res] return res_list - elif ds.type == 'redshift': + elif equals_ignore_case(ds.type, 'redshift'): with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: @@ -307,7 +307,7 @@ def get_schema(ds: CoreDatasource): res = cursor.fetchall() res_list = [item[0] for item in res] return res_list - elif ds.type == 'kingbase': + elif equals_ignore_case(ds.type, 'kingbase'): with psycopg2.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, options=f"-c statement_timeout={conf.timeout * 1000}", @@ -319,7 +319,7 @@ def get_schema(ds: CoreDatasource): def get_tables(ds: CoreDatasource): - conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if not equals_ignore_case(ds.type, "excel") else get_engine_config() db = DB.get_db(ds.type) sql, sql_param = get_table_sql(ds, conf, get_version(ds)) if db.connect_type == ConnectType.sqlalchemy: @@ -330,14 +330,14 @@ def get_tables(ds: CoreDatasource): return res_list else: extra_config_dict = get_extra_config(conf) - if ds.type == 'dm': + if equals_ignore_case(ds.type, 'dm'): with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: cursor.execute(sql, {"param": sql_param}, timeout=conf.timeout) res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] return res_list - elif ds.type == 'doris' or ds.type == "starrocks": + elif equals_ignore_case(ds.type, 'doris', 'starrocks'): with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: @@ -345,7 +345,7 @@ def get_tables(ds: CoreDatasource): res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] return res_list - elif ds.type == 'redshift': + elif equals_ignore_case(ds.type, 'redshift'): with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: @@ -353,7 +353,7 @@ def get_tables(ds: CoreDatasource): res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] return res_list - elif ds.type == 'kingbase': + elif equals_ignore_case(ds.type, 'kingbase'): with psycopg2.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, options=f"-c statement_timeout={conf.timeout * 1000}", @@ -362,14 +362,14 @@ def get_tables(ds: CoreDatasource): res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] return res_list - elif ds.type == 'es': + elif equals_ignore_case(ds.type, 'es'): res = get_es_index(conf) res_list = [TableSchema(*item) for item in res] return res_list def get_fields(ds: CoreDatasource, table_name: str = None): - conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if not equals_ignore_case(ds.type, "excel") else get_engine_config() db = DB.get_db(ds.type) sql, p1, p2 = get_field_sql(ds, conf, table_name) if db.connect_type == ConnectType.sqlalchemy: @@ -380,14 +380,14 @@ def get_fields(ds: CoreDatasource, table_name: str = None): return res_list else: extra_config_dict = get_extra_config(conf) - if ds.type == 'dm': + if equals_ignore_case(ds.type, 'dm'): with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: cursor.execute(sql, {"param1": p1, "param2": p2}, timeout=conf.timeout) res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list - elif ds.type == 'doris' or ds.type == "starrocks": + elif equals_ignore_case(ds.type, 'doris', 'starrocks'): with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: @@ -395,7 +395,7 @@ def get_fields(ds: CoreDatasource, table_name: str = None): res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list - elif ds.type == 'redshift': + elif equals_ignore_case(ds.type, 'redshift'): with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: @@ -403,7 +403,7 @@ def get_fields(ds: CoreDatasource, table_name: str = None): res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list - elif ds.type == 'kingbase': + elif equals_ignore_case(ds.type, 'kingbase'): with psycopg2.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, options=f"-c statement_timeout={conf.timeout * 1000}", @@ -412,7 +412,7 @@ def get_fields(ds: CoreDatasource, table_name: str = None): res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list - elif ds.type == 'es': + elif equals_ignore_case(ds.type, 'es'): res = get_es_fields(conf, table_name) res_list = [ColumnSchema(*item) for item in res] return res_list @@ -441,7 +441,7 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= else: conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) extra_config_dict = get_extra_config(conf) - if ds.type == 'dm': + if equals_ignore_case(ds.type, 'dm'): with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: try: @@ -459,7 +459,7 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: raise ParseSQLResultError(str(ex)) - elif ds.type == 'doris' or ds.type == "starrocks": + elif equals_ignore_case(ds.type, 'doris', 'starrocks'): with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: @@ -478,7 +478,7 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: raise ParseSQLResultError(str(ex)) - elif ds.type == 'redshift': + elif equals_ignore_case(ds.type, 'redshift'): with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: @@ -497,7 +497,7 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: raise ParseSQLResultError(str(ex)) - elif ds.type == 'kingbase': + elif equals_ignore_case(ds.type, 'kingbase'): with psycopg2.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, options=f"-c statement_timeout={conf.timeout * 1000}", @@ -517,7 +517,7 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: raise ParseSQLResultError(str(ex)) - elif ds.type == 'es': + elif equals_ignore_case(ds.type, 'es'): try: res, columns = get_es_data_by_http(conf, sql) columns = [field.get('name') for field in columns] if origin_column else [field.get('name').lower() for diff --git a/backend/apps/db/db_sql.py b/backend/apps/db/db_sql.py index 46745bca..d217ce33 100644 --- a/backend/apps/db/db_sql.py +++ b/backend/apps/db/db_sql.py @@ -1,39 +1,40 @@ # Author: Junjun # Date: 2025/8/20 from apps.datasource.models.datasource import CoreDatasource, DatasourceConf +from common.utils.utils import equals_ignore_case def get_version_sql(ds: CoreDatasource, conf: DatasourceConf): - if ds.type == "mysql" or ds.type == "doris" or ds.type == "starrocks": + if equals_ignore_case(ds.type, "mysql", "doris", "starrocks"): return """ SELECT VERSION() """ - elif ds.type == "sqlServer": + elif equals_ignore_case(ds.type, "sqlServer"): return """ select SERVERPROPERTY('ProductVersion') """ - elif ds.type == "pg" or ds.type == "kingbase" or ds.type == "excel": + elif equals_ignore_case(ds.type, "pg", "kingbase", "excel"): return """ SELECT current_setting('server_version') """ - elif ds.type == "oracle": + elif equals_ignore_case(ds.type, "oracle"): return """ SELECT version FROM v$instance """ - elif ds.type == "ck": + elif equals_ignore_case(ds.type, "ck"): return """ select version() """ - elif ds.type == 'dm': + elif equals_ignore_case(ds.type, "dm"): return """ SELECT * FROM v$version """ - elif ds.type == 'redshift': + elif equals_ignore_case(ds.type, "redshift"): return '' def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = ''): - if ds.type == "mysql": + if equals_ignore_case(ds.type, "mysql"): return """ SELECT TABLE_NAME, @@ -43,7 +44,7 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = '' WHERE TABLE_SCHEMA = :param """, conf.database - elif ds.type == "sqlServer": + elif equals_ignore_case(ds.type, "sqlServer"): return """ SELECT TABLE_NAME AS [TABLE_NAME], @@ -59,7 +60,7 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = '' t.TABLE_TYPE IN ('BASE TABLE', 'VIEW') AND t.TABLE_SCHEMA = :param """, conf.dbSchema - elif ds.type == "pg" or ds.type == "excel": + elif equals_ignore_case(ds.type, "pg", "excel"): return """ SELECT c.relname AS TABLE_NAME, COALESCE(d.description, obj_description(c.oid)) AS TABLE_COMMENT @@ -74,7 +75,7 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = '' AND c.relname NOT LIKE 'sql_%' ORDER BY c.relname \ """, conf.dbSchema - elif ds.type == "oracle": + elif equals_ignore_case(ds.type, "oracle"): return """ SELECT DISTINCT t.TABLE_NAME AS "TABLE_NAME", @@ -98,7 +99,7 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = '' AND c.OWNER = :param ORDER BY t.TABLE_NAME """, conf.dbSchema - elif ds.type == "ck": + elif equals_ignore_case(ds.type, "ck"): version = int(db_version.split('.')[0]) if version < 22: return """ @@ -116,14 +117,14 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = '' AND engine NOT IN ('Dictionary') ORDER BY name """, conf.database - elif ds.type == 'dm': + elif equals_ignore_case(ds.type, "dm"): return """ select table_name, comments from all_tab_comments where owner=:param AND (table_type = 'TABLE' or table_type = 'VIEW') """, conf.dbSchema - elif ds.type == 'redshift': + elif equals_ignore_case(ds.type, "redshift"): return """ SELECT relname AS TableName, @@ -134,7 +135,7 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = '' relkind in ('r','p', 'f') AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = %s) """, conf.dbSchema - elif ds.type == "doris" or ds.type == "starrocks": + elif equals_ignore_case(ds.type, "doris", "starrocks"): return """ SELECT TABLE_NAME, @@ -144,7 +145,7 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = '' WHERE TABLE_SCHEMA = %s """, conf.database - elif ds.type == "kingbase": + elif equals_ignore_case(ds.type, "kingbase"): return """ SELECT c.relname AS TABLE_NAME, COALESCE(d.description, obj_description(c.oid)) AS TABLE_COMMENT @@ -159,12 +160,12 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = '' AND c.relname NOT LIKE 'sql_%' ORDER BY c.relname \ """, conf.dbSchema - elif ds.type == "es": + elif equals_ignore_case(ds.type, "es"): return "", None def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = None): - if ds.type == "mysql": + if equals_ignore_case(ds.type, "mysql"): sql1 = """ SELECT COLUMN_NAME, @@ -177,7 +178,7 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No """ sql2 = " AND TABLE_NAME = :param2" if table_name is not None and table_name != "" else "" return sql1 + sql2, conf.database, table_name - elif ds.type == "sqlServer": + elif equals_ignore_case(ds.type, "sqlServer"): sql1 = """ SELECT COLUMN_NAME AS [COLUMN_NAME], @@ -195,7 +196,7 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No """ sql2 = " AND C.TABLE_NAME = :param2" if table_name is not None and table_name != "" else "" return sql1 + sql2, conf.dbSchema, table_name - elif ds.type == "pg" or ds.type == "excel": + elif equals_ignore_case(ds.type, "pg", "excel"): sql1 = """ SELECT a.attname AS COLUMN_NAME, pg_catalog.format_type(a.atttypid, a.atttypmod) AS DATA_TYPE, @@ -211,7 +212,7 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No """ sql2 = " AND c.relname = :param2" if table_name is not None and table_name != "" else "" return sql1 + sql2, conf.dbSchema, table_name - elif ds.type == "redshift": + elif equals_ignore_case(ds.type, "redshift"): sql1 = """ SELECT a.attname AS COLUMN_NAME, pg_catalog.format_type(a.atttypid, a.atttypmod) AS DATA_TYPE, @@ -227,7 +228,7 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No """ sql2 = " AND c.relname = %s" if table_name is not None and table_name != "" else "" return sql1 + sql2, conf.dbSchema, table_name - elif ds.type == "oracle": + elif equals_ignore_case(ds.type, "oracle"): sql1 = """ SELECT col.COLUMN_NAME AS "COLUMN_NAME", @@ -252,7 +253,7 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No """ sql2 = " AND col.TABLE_NAME = :param2" if table_name is not None and table_name != "" else "" return sql1 + sql2, conf.dbSchema, table_name - elif ds.type == "ck": + elif equals_ignore_case(ds.type, "ck"): sql1 = """ SELECT name AS COLUMN_NAME, @@ -263,7 +264,7 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No """ sql2 = " AND table = :param2" if table_name is not None and table_name != "" else "" return sql1 + sql2, conf.database, table_name - elif ds.type == 'dm': + elif equals_ignore_case(ds.type, "dm"): sql1 = """ SELECT c.COLUMN_NAME AS "COLUMN_NAME", @@ -281,7 +282,7 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No """ sql2 = " AND c.TABLE_NAME = :param2" if table_name is not None and table_name != "" else "" return sql1 + sql2, conf.dbSchema, table_name - elif ds.type == "doris" or ds.type == "starrocks": + elif equals_ignore_case(ds.type, "doris", "starrocks"): sql1 = """ SELECT COLUMN_NAME, @@ -294,7 +295,7 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No """ sql2 = " AND TABLE_NAME = %s" if table_name is not None and table_name != "" else "" return sql1 + sql2, conf.database, table_name - elif ds.type == "kingbase": + elif equals_ignore_case(ds.type, "kingbase"): sql1 = """ SELECT a.attname AS COLUMN_NAME, pg_catalog.format_type(a.atttypid, a.atttypmod) AS DATA_TYPE, @@ -310,5 +311,5 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No """ sql2 = " AND c.relname = '{1}'" if table_name is not None and table_name != "" else "" return sql1 + sql2, conf.dbSchema, table_name - elif ds.type == "es": + elif equals_ignore_case(ds.type, "es"): return "", None, None diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py index 58bf07a6..72168221 100644 --- a/backend/apps/system/crud/assistant.py +++ b/backend/apps/system/crud/assistant.py @@ -17,7 +17,7 @@ from common.core.db import engine from common.core.sqlbot_cache import cache from common.utils.aes_crypto import simple_aes_decrypt -from common.utils.utils import string_to_numeric_hash +from common.utils.utils import equals_ignore_case, string_to_numeric_hash @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id") @@ -236,18 +236,15 @@ def get_ds_engine(ds: AssistantOutDsSchema) -> Engine: conf.extraJdbc = '' from apps.db.db import get_uri_from_config uri = get_uri_from_config(ds.type, conf) - # if ds.type == "pg" and ds.db_schema: - # connect_args.update({"options": f"-c search_path={ds.db_schema}"}) - # engine = create_engine(uri, connect_args=connect_args, pool_timeout=timeout, pool_size=20, max_overflow=10) - if ds.type == "pg" and ds.db_schema: + if equals_ignore_case(ds.type, "pg") and ds.db_schema: engine = create_engine(uri, connect_args={"options": f"-c search_path={urllib.parse.quote(ds.db_schema)}", "connect_timeout": timeout}, pool_timeout=timeout) - elif ds.type == 'sqlServer': + elif equals_ignore_case(ds.type, 'sqlServer'): engine = create_engine(uri, pool_timeout=timeout) - elif ds.type == 'oracle': + elif equals_ignore_case(ds.type, 'oracle'): engine = create_engine(uri, pool_timeout=timeout) else: diff --git a/backend/common/utils/utils.py b/backend/common/utils/utils.py index 9b2570d1..ee6b0964 100644 --- a/backend/common/utils/utils.py +++ b/backend/common/utils/utils.py @@ -263,3 +263,13 @@ def get_origin_from_referer(request: Request): SQLBotLogUtil.error(f"解析 Referer 出错: {e}") return referer + +def equals_ignore_case(str1: str, *args: str) -> bool: + if str1 is None: + return None in args + for arg in args: + if arg is None: + continue + if str1.casefold() == arg.casefold(): + return True + return False \ No newline at end of file