Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,35 @@ def __init__(self):

self.database_engine = DatabaseEngine.DATABRICKS

@property
def invalid_identifiers(self) -> list[str]:
    """Return the identifier names that cause a SQL query to be rejected.

    Every name is listed exactly once; the earlier version of this list
    repeated ``CURRENT_USER`` and ``SESSION_USER``, which had no effect on
    membership checks but made the list harder to maintain.

    Returns:
        A new list of upper-case names on every access.
    """
    return [
        # Session and system variables
        "CURRENT_CATALOG",
        "CURRENT_DATABASE",
        "CURRENT_USER",
        "SESSION_USER",
        "CURRENT_ROLE",
        "CURRENT_QUERY",
        "CURRENT_WAREHOUSE",
        "SESSION_ID",
        # System metadata functions
        "DATABASE",
        "USER",
        # Potentially unsafe built-in functions and statements
        "SYSTEM",
        "SHOW",
        "DESCRIBE",
        "EXPLAIN",
        "SET",
        # NOTE(review): multi-word entries are kept from the original list;
        # verify that query validation can ever produce such a token.
        "SHOW TABLES",
        "SHOW COLUMNS",
        "SHOW DATABASES",
    ]

async def query_execution(
self,
sql_query: Annotated[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,48 @@ def __init__(self):

self.database_engine = DatabaseEngine.SNOWFLAKE

@property
def invalid_identifiers(self) -> list[str]:
    """Return the identifier names that cause a SQL query to be rejected.

    Returns:
        A new list of upper-case names on every access, in a fixed order.
    """
    # NOTE(review): the grouping below appears to mirror Snowflake's
    # context-function categories; it is for readability only — confirm.
    general_context = (
        "CURRENT_CLIENT",
        "CURRENT_IP_ADDRESS",
        "CURRENT_REGION",
        "CURRENT_VERSION",
    )
    session_context = (
        "ALL_USER_NAMES",
        "CURRENT_ACCOUNT",
        "CURRENT_ACCOUNT_NAME",
        "CURRENT_ORGANIZATION_NAME",
        "CURRENT_ROLE",
        "CURRENT_AVAILABLE_ROLES",
        "CURRENT_SECONDARY_ROLES",
        "CURRENT_SESSION",
        "CURRENT_STATEMENT",
        "CURRENT_TRANSACTION",
        "CURRENT_USER",
        "GETVARIABLE",
        "LAST_QUERY_ID",
        "LAST_TRANSACTION",
    )
    session_object_context = (
        "CURRENT_DATABASE",
        "CURRENT_ROLE_TYPE",
        "CURRENT_SCHEMA",
        "CURRENT_SCHEMAS",
        "CURRENT_WAREHOUSE",
        "INVOKER_ROLE",
        "INVOKER_SHARE",
        "IS_APPLICATION_ROLE_IN_SESSION",
        "IS_DATABASE_ROLE_IN_SESSION",
        "IS_GRANTED_TO_INVOKER_ROLE",
        "IS_INSTANCE_ROLE_IN_SESSION",
        "IS_ROLE_IN_SESSION",
        "POLICY_CONTEXT",
    )
    session_and_query_details = (
        "CURRENT_SESSION_USER",
        "SESSION_ID",
        "QUERY_START_TIME",
        "QUERY_ELAPSED_TIME",
        "QUERY_MEMORY_USAGE",
    )
    return [
        *general_context,
        *session_context,
        *session_object_context,
        *session_and_query_details,
    ]

async def query_execution(
self,
sql_query: Annotated[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from text_2_sql_core.connectors.factory import ConnectorFactory
import asyncio
import sqlglot
from sqlglot.expressions import Parameter, Select, Identifier
from abc import ABC, abstractmethod
from jinja2 import Template
import json
Expand All @@ -29,6 +30,12 @@ def __init__(self):

self.database_engine = None

@property
@abstractmethod
def invalid_identifiers(self) -> list[str]:
    """Return the identifier names that cause a SQL query to be rejected.

    Concrete connectors override this property. Query validation checks
    each extracted token against the returned list; tokens that are not
    parameters are upper-cased before the comparison, so entries are
    expected to be upper-case.
    """

@abstractmethod
async def query_execution(
self,
Expand Down Expand Up @@ -123,11 +130,49 @@ async def query_validation(
"""Validate the SQL query."""
try:
logging.info("Validating SQL Query: %s", sql_query)
sqlglot.transpile(
parsed_queries = sqlglot.parse(
sql_query,
read=self.database_engine.value.lower(),
error_level=sqlglot.ErrorLevel.RAISE,
)

expressions = []
identifiers = []

def handle_node(node):
if isinstance(node, Select):
# Extract expressions
for expr in node.expressions:
expressions.append(expr)
elif isinstance(node, Identifier):
# Extract identifiers
identifiers.append(node.this)

detected_invalid_identifiers = []

for parsed_query in parsed_queries:
for node in parsed_query.walk():
handle_node(node)

for token in expressions + identifiers:
if isinstance(token, Parameter):
identifier = token.this.this
else:
identifier = str(token).strip("()").upper()

if identifier in self.invalid_identifiers:
logging.warning("Detected invalid identifier: %s", identifier)
detected_invalid_identifiers.append(identifier)

if len(detected_invalid_identifiers) > 0:
logging.error(
"SQL Query contains invalid identifiers: %s",
detected_invalid_identifiers,
)
return (
"SQL Query contains invalid identifiers: %s"
% detected_invalid_identifiers
)

except sqlglot.errors.ParseError as e:
logging.error("SQL Query is invalid: %s", e.errors)
return e.errors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,45 @@ def __init__(self):

self.database_engine = DatabaseEngine.TSQL

@property
def invalid_identifiers(self) -> list[str]:
    """Return the identifier names that cause a SQL query to be rejected.

    Returns:
        A new list of upper-case names on every access, in a fixed order.
    """
    # NOTE(review): these look like T-SQL ``@@`` system function names with
    # the ``@@`` prefix removed — confirm against the validation logic.
    rejected_names = (
        "CONNECTIONS",
        "CPU_BUSY",
        "CURSOR_ROWS",
        "DATEFIRST",
        "DBTS",
        "ERROR",
        "FETCH_STATUS",
        "IDENTITY",
        "IDLE",
        "IO_BUSY",
        "LANGID",
        "LANGUAGE",
        "LOCK_TIMEOUT",
        "MAX_CONNECTIONS",
        "MAX_PRECISION",
        "NESTLEVEL",
        "OPTIONS",
        "PACK_RECEIVED",
        "PACK_SENT",
        "PACKET_ERRORS",
        "PROCID",
        "REMSERVER",
        "ROWCOUNT",
        "SERVERNAME",
        "SERVICENAME",
        "SPID",
        "TEXTSIZE",
        "TIMETICKS",
        "TOTAL_ERRORS",
        "TOTAL_READ",
        "TOTAL_WRITE",
        "TRANCOUNT",
        "VERSION",
    )
    return list(rejected_names)

async def query_execution(
self,
sql_query: Annotated[
Expand Down
Loading