From 813441702b5b10206398534f3230dcff21200d55 Mon Sep 17 00:00:00 2001 From: junjun Date: Thu, 29 Jan 2026 13:51:06 +0800 Subject: [PATCH] fix: check sql only contain read operation #814 --- backend/apps/db/db.py | 33 ++++++++++++++++++++++++++++++++- backend/pyproject.toml | 1 + 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py index fd58d212..2f3c8b8a 100644 --- a/backend/apps/db/db.py +++ b/backend/apps/db/db.py @@ -3,8 +3,8 @@ import os import platform import urllib.parse -from decimal import Decimal from datetime import timedelta +from decimal import Decimal from typing import Optional import oracledb @@ -32,6 +32,8 @@ from fastapi import HTTPException from apps.db.es_engine import get_es_connect, get_es_index, get_es_fields, get_es_data_by_http from common.core.config import settings +import sqlglot +from sqlglot import expressions as exp try: if os.path.exists(settings.ORACLE_CLIENT_PATH): @@ -464,6 +466,9 @@ def convert_value(value): def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False): while sql.endswith(';'): sql = sql[:-1] + # check execute sql only contain read operations + if not check_sql_read(sql): + raise ValueError(f"SQL can only contain read operations") db = DB.get_db(ds.type) if db.connect_type == ConnectType.sqlalchemy: @@ -569,3 +574,29 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: raise Exception(str(ex)) + + +def check_sql_read(sql: str, dialect=None): + try: + + statements = sqlglot.parse(sql, dialect=dialect) + + if not statements: + raise ValueError("Parse SQL Error") + + write_types = ( + exp.Insert, exp.Update, exp.Delete, + exp.Create, exp.Drop, exp.Alter, + exp.Merge, exp.Command + ) + + for stmt in statements: + if stmt is None: + continue + if isinstance(stmt, write_types): + return False + + return True + + except Exception as e: + raise ValueError(f"Parse SQL Error: {e}") diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 2f4e9d31..a133dc83 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "redshift-connector>=2.1.8", "elasticsearch[requests] (>=7.10,<8.0)", "ldap3>=2.9.1", + "sqlglot>=28.6.0", ] [project.optional-dependencies]