Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion backend/apps/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import os
import platform
import urllib.parse
from decimal import Decimal
from datetime import timedelta
from decimal import Decimal
from typing import Optional

import oracledb
Expand Down Expand Up @@ -32,6 +32,8 @@
from fastapi import HTTPException
from apps.db.es_engine import get_es_connect, get_es_index, get_es_fields, get_es_data_by_http
from common.core.config import settings
import sqlglot
from sqlglot import expressions as exp

try:
if os.path.exists(settings.ORACLE_CLIENT_PATH):
Expand Down Expand Up @@ -464,6 +466,9 @@ def convert_value(value):
def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False):
while sql.endswith(';'):
sql = sql[:-1]
# check execute sql only contain read operations
if not check_sql_read(sql):
raise ValueError(f"SQL can only contain read operations")

db = DB.get_db(ds.type)
if db.connect_type == ConnectType.sqlalchemy:
Expand Down Expand Up @@ -569,3 +574,29 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
"sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))}
except Exception as ex:
raise Exception(str(ex))


def check_sql_read(sql: str, dialect=None):
try:

statements = sqlglot.parse(sql, dialect=dialect)

if not statements:
raise ValueError("Parse SQL Error")

write_types = (
exp.Insert, exp.Update, exp.Delete,
exp.Create, exp.Drop, exp.Alter,
exp.Merge, exp.Command
)

for stmt in statements:
if stmt is None:
continue
if isinstance(stmt, write_types):
return False

return True

except Exception as e:
raise ValueError(f"Parse SQL Error: {e}")
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ dependencies = [
"redshift-connector>=2.1.8",
"elasticsearch[requests] (>=7.10,<8.0)",
"ldap3>=2.9.1",
"sqlglot>=28.6.0",
]

[project.optional-dependencies]
Expand Down