From 190f5d21e8e4d7369acbb3a883b3ae7a3d7f1b29 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 3 Jan 2025 14:45:49 +0000 Subject: [PATCH 1/8] Add postgres support --- .../text_2_sql_schema_store.py | 2 + text_2_sql/autogen/pyproject.toml | 2 +- text_2_sql/text_2_sql_core/pyproject.toml | 4 + .../connectors/postgresql_sql.py | 99 +++++++++++++++++++ .../data_dictionary_creator.py | 2 + .../postgresql_data_dictionary_creator.py | 87 ++++++++++++++++ .../src/text_2_sql_core/utils/database.py | 1 + 7 files changed, 196 insertions(+), 1 deletion(-) create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/postgresql_data_dictionary_creator.py diff --git a/deploy_ai_search/src/deploy_ai_search/text_2_sql_schema_store.py b/deploy_ai_search/src/deploy_ai_search/text_2_sql_schema_store.py index bbfcf310..ee4d7055 100644 --- a/deploy_ai_search/src/deploy_ai_search/text_2_sql_schema_store.py +++ b/deploy_ai_search/src/deploy_ai_search/text_2_sql_schema_store.py @@ -65,6 +65,8 @@ def excluded_fields_for_database_engine(self): engine_specific_fields = ["Database"] elif self.database_engine == DatabaseEngine.DATABRICKS: engine_specific_fields = ["Catalog"] + elif self.database_engine == DatabaseEngine.POSTGRESQL: + engine_specific_fields = ["Database"] return [ field diff --git a/text_2_sql/autogen/pyproject.toml b/text_2_sql/autogen/pyproject.toml index c72e5d6a..977d6e9e 100644 --- a/text_2_sql/autogen/pyproject.toml +++ b/text_2_sql/autogen/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "autogen-ext[azure,openai]==0.4.0.dev11", "grpcio>=1.68.1", "pyyaml>=6.0.2", - "text_2_sql_core[snowflake,databricks]", + "text_2_sql_core[snowflake,databricks,postgresql]", ] [dependency-groups] diff --git a/text_2_sql/text_2_sql_core/pyproject.toml b/text_2_sql/text_2_sql_core/pyproject.toml index f5af69b7..9aeaa500 100644 --- a/text_2_sql/text_2_sql_core/pyproject.toml +++ b/text_2_sql/text_2_sql_core/pyproject.toml @@ -46,6 +46,10 @@ databricks = [ "databricks-sql-connector>=3.0.1", "pyarrow>=14.0.2,<17", ] +postgresql = [ + "psycopg>=3.2.3", +] + [build-system] requires = ["hatchling"] diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py new file mode 100644 index 00000000..809ca620 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from text_2_sql_core.connectors.sql import SqlConnector +import psycopg +from typing import Annotated +import os +import logging +import json + +from text_2_sql_core.utils.database import DatabaseEngine + + +class PostgresqlSqlConnector(SqlConnector): + def __init__(self): + super().__init__() + + self.database_engine = DatabaseEngine.POSTGRESQL + + async def query_execution( + self, + sql_query: Annotated[str, "The SQL query to run against the database."], + cast_to: any = None, + limit=None, + ) -> list[dict]: + """Run the SQL query against the PostgreSQL database asynchronously. + + Args: + ---- + sql_query (str): The SQL query to run against the database. + + Returns: + ------- + list[dict]: The results of the SQL query. + """ + logging.info(f"Running query: {sql_query}") + results = [] + connection_string = os.environ["Text2Sql__DatabaseConnectionString"] + + # Establish an asynchronous connection to the PostgreSQL database + async with psycopg.AsyncConnection.connect(connection_string) as conn: + # Create an asynchronous cursor + async with conn.cursor() as cursor: + await cursor.execute(sql_query) + + # Fetch column names + columns = [column[0] for column in cursor.description] + + # Fetch rows based on the limit + if limit is not None: + rows = await cursor.fetchmany(limit) + else: + rows = await cursor.fetchall() + + # Process the rows + for row in rows: + if cast_to: + results.append(cast_to.from_sql_row(row, columns)) + else: + results.append(dict(zip(columns, row))) + + logging.debug("Results: %s", results) + return results + + async def get_entity_schemas( + self, + text: Annotated[ + str, + "The text to run a semantic search against. Relevant entities will be returned.", + ], + excluded_entities: Annotated[ + list[str], + "The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.", + ] = [], + as_json: bool = True, + ) -> str: + """Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. + + Args: + ---- + text (str): The text to run the search against. + + Returns: + str: The schema of the views or tables in JSON format. + """ + + schemas = await self.ai_search_connector.get_entity_schemas( + text, excluded_entities + ) + + for schema in schemas: + schema["SelectFromEntity"] = ".".join([schema["Schema"], schema["Entity"]]) + + del schema["Entity"] + del schema["Schema"] + + if as_json: + return json.dumps(schemas, default=str) + else: + return schemas diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py index c3d1cf1f..f6f050b4 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py @@ -758,6 +758,8 @@ def excluded_fields_for_database_engine(self): engine_specific_fields = ["Database"] elif self.database_engine == DatabaseEngine.DATABRICKS: engine_specific_fields = ["Catalog"] + elif self.database_engine == DatabaseEngine.POSTGRESQL: + engine_specific_fields = ["Database"] else: engine_specific_fields = [] diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/postgresql_data_dictionary_creator.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/postgresql_data_dictionary_creator.py new file mode 100644 index 00000000..c020a8dc --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/postgresql_data_dictionary_creator.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from text_2_sql_core.data_dictionary.data_dictionary_creator import ( + DataDictionaryCreator, + EntityItem, +) +import os + +from text_2_sql_core.utils.database import DatabaseEngine +from text_2_sql_core.connectors.postgresql_sql import PostgresSqlConnector + + +class PostgresqlDataDictionaryCreator(DataDictionaryCreator): + def __init__(self, **kwargs): + """A method to initialize the DataDictionaryCreator class.""" + super().__init__(**kwargs) + + self.database = os.environ["Text2Sql__DatabaseName"] + self.database_engine = DatabaseEngine.POSTGRESQL + + self.sql_connector = PostgresSqlConnector() + + @property + def extract_table_entities_sql_query(self) -> str: + """A property to extract table entities from a PostgreSQL database.""" + return """SELECT + t.table_name AS entity, + t.table_schema AS entity_schema, + pg_catalog.col_description(c.oid, 0) AS definition + FROM + information_schema.tables t + JOIN + pg_catalog.pg_class c ON c.relname = t.table_name + AND c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = t.table_schema) + WHERE + t.table_type = 'BASE TABLE' + ORDER BY entity_schema, entity;""" + + @property + def extract_view_entities_sql_query(self) -> str: + """A property to extract view entities from a PostgreSQL database.""" + return """SELECT + v.view_name AS entity, + v.table_schema AS entity_schema, + pg_catalog.col_description(c.oid, 0) AS definition + FROM + information_schema.views v + JOIN + pg_catalog.pg_class c ON c.relname = v.view_name + AND c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = v.table_schema) + ORDER BY entity_schema, entity;""" + + def extract_columns_sql_query(self, entity: EntityItem) -> str: + """A property to extract column information from a PostgreSQL database.""" + return f"""SELECT + c.column_name AS name, + c.data_type AS data_type, + pg_catalog.col_description(t.oid, c.ordinal_position) AS definition + FROM + information_schema.columns c + JOIN + pg_catalog.pg_class t ON t.relname = c.table_name + AND t.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema) + WHERE + c.table_schema = '{entity.entity_schema}' + AND c.table_name = '{entity.name}' + ORDER BY c.ordinal_position;""" + + @property + def extract_entity_relationships_sql_query(self) -> str: + """A property to extract entity relationships from a PostgreSQL database.""" + return """SELECT + tc.table_schema AS entity_schema, + tc.table_name AS entity, + rc.unique_constraint_schema AS foreign_entity_schema, + rc.unique_constraint_name AS foreign_entity_constraint, + rc.constraint_name AS foreign_key_constraint + FROM + information_schema.referential_constraints rc + JOIN + information_schema.table_constraints tc + ON rc.constraint_schema = tc.constraint_schema + AND rc.constraint_name = tc.constraint_name + WHERE + tc.constraint_type = 'FOREIGN KEY' + ORDER BY + entity_schema, entity, foreign_entity_schema, foreign_entity_constraint;""" diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/utils/database.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/utils/database.py index 4ee1b796..75a4969d 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/utils/database.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/utils/database.py @@ -7,3 +7,4 @@ class DatabaseEngine(StrEnum): DATABRICKS = "DATABRICKS" SNOWFLAKE = "SNOWFLAKE" TSQL = "TSQL" + POSTGRESQL = "POSTGRESQL" From 825d2c8f457becbd1415b33543848f8eef8370ba Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 3 Jan 2025 15:03:06 +0000 Subject: [PATCH 2/8] Store excluded fields in --- .../text_2_sql_schema_store.py | 25 +++---------------- .../connectors/databricks_sql.py | 7 +++++- .../src/text_2_sql_core/connectors/factory.py | 6 +++++ .../connectors/postgresql_sql.py | 7 +++++- .../connectors/snowflake_sql.py | 10 +++++++- .../src/text_2_sql_core/connectors/sql.py | 17 +++++++++++++ .../text_2_sql_core/connectors/tsql_sql.py | 7 +++++- .../data_dictionary_creator.py | 17 +------------ .../src/text_2_sql_core/utils/database.py | 8 ++++++ 9 files changed, 63 insertions(+), 41 deletions(-) diff --git a/deploy_ai_search/src/deploy_ai_search/text_2_sql_schema_store.py b/deploy_ai_search/src/deploy_ai_search/text_2_sql_schema_store.py index ee4d7055..28d59b3f 100644 --- a/deploy_ai_search/src/deploy_ai_search/text_2_sql_schema_store.py +++ b/deploy_ai_search/src/deploy_ai_search/text_2_sql_schema_store.py @@ -26,6 +26,7 @@ ) import os from text_2_sql_core.utils.database import DatabaseEngine +from text_2_sql_core.connectors.factory import ConnectorFactory class Text2SqlSchemaStoreAISearch(AISearch): @@ -49,31 +50,13 @@ def __init__( os.environ["Text2Sql__DatabaseEngine"].upper() ] + self.database_connector = ConnectorFactory.get_database_connector() + if single_data_dictionary_file: self.parsing_mode = BlobIndexerParsingMode.JSON_ARRAY else: self.parsing_mode = BlobIndexerParsingMode.JSON - @property - def excluded_fields_for_database_engine(self): - """A method to get the excluded fields for the database engine.""" - - all_engine_specific_fields = ["Warehouse", "Database", "Catalog"] - if self.database_engine == DatabaseEngine.SNOWFLAKE: - engine_specific_fields = ["Warehouse", "Database"] - elif self.database_engine == DatabaseEngine.TSQL: - engine_specific_fields = ["Database"] - elif self.database_engine == DatabaseEngine.DATABRICKS: - engine_specific_fields = ["Catalog"] - elif self.database_engine == DatabaseEngine.POSTGRESQL: - engine_specific_fields = ["Database"] - - return [ - field - for field in all_engine_specific_fields - if field not in engine_specific_fields - ] - def get_index_fields(self) -> list[SearchableField]: """This function returns the index fields for sql index. @@ -198,7 +181,7 @@ def get_index_fields(self) -> list[SearchableField]: fields = [ field for field in fields - if field.name not in self.excluded_fields_for_database_engine + if field.name not in self.database_connector.excluded_engine_specific_fields ] return fields diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py index 14c99ac2..e345dd38 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py @@ -8,7 +8,7 @@ import logging import json -from text_2_sql_core.utils.database import DatabaseEngine +from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields class DatabricksSqlConnector(SqlConnector): @@ -17,6 +17,11 @@ def __init__(self): self.database_engine = DatabaseEngine.DATABRICKS + @property + def engine_specific_fields(self) -> list[str]: + """Get the engine specific fields.""" + return [DatabaseEngineSpecificFields.CATALOG] + async def query_execution( self, sql_query: Annotated[ diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/factory.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/factory.py index b2c4f6e2..609906d7 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/factory.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/factory.py @@ -25,6 +25,12 @@ def get_database_connector(): from text_2_sql_core.connectors.tsql_sql import TSQLSqlConnector return TSQLSqlConnector() + elif os.environ["Text2Sql__DatabaseEngine"].upper() == "POSTGRESQL": + from text_2_sql_core.connectors.postgresql_sql import ( + PostgresqlSqlConnector, + ) + + return PostgresqlSqlConnector() else: raise ValueError( f"""Database engine { diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py index 809ca620..3142ade3 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py @@ -7,7 +7,7 @@ import logging import json -from text_2_sql_core.utils.database import DatabaseEngine +from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields class PostgresqlSqlConnector(SqlConnector): @@ -16,6 +16,11 @@ def __init__(self): self.database_engine = DatabaseEngine.POSTGRESQL + @property + def engine_specific_fields(self) -> list[str]: + """Get the engine specific fields.""" + return [DatabaseEngineSpecificFields.DATABASE] + async def query_execution( self, sql_query: Annotated[str, "The SQL query to run against the database."], diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py index c8da9c26..b606912d 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py @@ -8,7 +8,7 @@ import logging import json -from text_2_sql_core.utils.database import DatabaseEngine +from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields class SnowflakeSqlConnector(SqlConnector): @@ -17,6 +17,14 @@ def __init__(self): self.database_engine = DatabaseEngine.SNOWFLAKE + @property + def engine_specific_fields(self) -> list[str]: + """Get the engine specific fields.""" + return [ + DatabaseEngineSpecificFields.WAREHOUSE, + DatabaseEngineSpecificFields.DATABASE, + ] + async def query_execution( self, sql_query: Annotated[ diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index c367cd1b..c1d19038 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -9,6 +9,7 @@ from abc import ABC, abstractmethod from jinja2 import Template import json +from text_2_sql_core.utils.database import DatabaseEngineSpecificFields class SqlConnector(ABC): @@ -29,6 +30,22 @@ def __init__(self): self.database_engine = None + @abstractmethod + @property + def engine_specific_fields(self) -> list[str]: + """Get the engine specific fields.""" + pass + + @property + def excluded_engine_specific_fields(self): + """A method to get the excluded fields for the database engine.""" + + return [ + field.value.capitalize() + for field in DatabaseEngineSpecificFields + if field not in self.engine_specific_fields + ] + @abstractmethod async def query_execution( self, diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py index e494bb18..e6919a87 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py @@ -7,7 +7,7 @@ import logging import json -from text_2_sql_core.utils.database import DatabaseEngine +from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields class TSQLSqlConnector(SqlConnector): @@ -16,6 +16,11 @@ def __init__(self): self.database_engine = DatabaseEngine.TSQL + @property + def engine_specific_fields(self) -> list[str]: + """Get the engine specific fields.""" + return [DatabaseEngineSpecificFields.DATABASE] + async def query_execution( self, sql_query: Annotated[ diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py index f6f050b4..7aaf30a4 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py @@ -11,7 +11,6 @@ import random import re import networkx as nx -from text_2_sql_core.utils.database import DatabaseEngine from tenacity import retry, stop_after_attempt, wait_exponential from text_2_sql_core.connectors.open_ai import OpenAIConnector @@ -751,23 +750,9 @@ async def build_entity_entry(self, entity: EntityItem) -> EntityItem: def excluded_fields_for_database_engine(self): """A method to get the excluded fields for the database engine.""" - all_engine_specific_fields = ["Warehouse", "Database", "Catalog"] - if self.database_engine == DatabaseEngine.SNOWFLAKE: - engine_specific_fields = ["Warehouse", "Database"] - elif self.database_engine == DatabaseEngine.TSQL: - engine_specific_fields = ["Database"] - elif self.database_engine == DatabaseEngine.DATABRICKS: - engine_specific_fields = ["Catalog"] - elif self.database_engine == DatabaseEngine.POSTGRESQL: - engine_specific_fields = ["Database"] - else: - engine_specific_fields = [] - # Determine top-level fields to exclude filtered_entitiy_specific_fields = { - field.lower(): ... - for field in all_engine_specific_fields - if field not in engine_specific_fields + field.lower(): ... for field in self.excluded_engine_specific_fields } if filtered_entitiy_specific_fields: diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/utils/database.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/utils/database.py index 75a4969d..168515b4 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/utils/database.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/utils/database.py @@ -8,3 +8,11 @@ class DatabaseEngine(StrEnum): SNOWFLAKE = "SNOWFLAKE" TSQL = "TSQL" POSTGRESQL = "POSTGRESQL" + + +class DatabaseEngineSpecificFields(StrEnum): + """An enumeration to represent the database engine specific fields.""" + + WAREHOUSE = "Warehouse" + DATABASE = "Database" + CATALOG = "Catalog" From f9d335cdf848decffce04035a3ed9ef16fa8d695 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 3 Jan 2025 16:21:45 +0000 Subject: [PATCH 3/8] Fix bad property --- .../text_2_sql_core/src/text_2_sql_core/connectors/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 7b3dffca..675a8346 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -37,8 +37,8 @@ def invalid_identifiers(self) -> list[str]: """Get the invalid identifiers upon which a sql query is rejected.""" pass - @abstractmethod @property + @abstractmethod def engine_specific_fields(self) -> list[str]: """Get the engine specific fields.""" pass From e6d62a3b51de8d8b4e40079450eb88dc1a8a4589 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 3 Jan 2025 16:26:36 +0000 Subject: [PATCH 4/8] Add postgres identifiers --- .../connectors/postgresql_sql.py | 22 +++++++++++++++++++ .../src/text_2_sql_core/connectors/sql.py | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py index 3142ade3..522554b0 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py @@ -21,6 +21,28 @@ def engine_specific_fields(self) -> list[str]: """Get the engine specific fields.""" return [DatabaseEngineSpecificFields.DATABASE] + @property + def invalid_identifiers(self) -> list[str]: + """Get the invalid identifiers upon which a sql query is rejected.""" + + return [ + "CURRENT_USER", # Returns the name of the current user + "SESSION_USER", # Returns the name of the user that initiated the session + "USER", # Returns the name of the current user + "CURRENT_ROLE", # Returns the current role + "CURRENT_DATABASE", # Returns the name of the current database + "CURRENT_SCHEMA()", # Returns the name of the current schema + "CURRENT_SETTING()", # Returns the value of a specified configuration parameter + "PG_CURRENT_XACT_ID()", # Returns the current transaction ID + # (if the extension is enabled) Provides a view of query statistics + "PG_STAT_STATEMENTS()", + "PG_SLEEP()", # Delays execution by the specified number of seconds + "CLIENT_ADDR()", # Returns the IP address of the client (from pg_stat_activity) + "CLIENT_HOSTNAME()", # Returns the hostname of the client (from pg_stat_activity) + "PGP_SYM_DECRYPT()", # (from pgcrypto extension) Symmetric decryption function + "PGP_PUB_DECRYPT()", # (from pgcrypto extension) Asymmetric decryption function + ] + async def query_execution( self, sql_query: Annotated[str, "The SQL query to run against the database."], diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 675a8346..21a93780 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -172,7 +172,7 @@ def handle_node(node): for token in expressions + identifiers: if isinstance(token, Parameter): - identifier = token.this.this + identifier = str(token.this.this).upper() else: identifier = str(token).strip("()").upper() From 98c9ee5f6857ab121e0705926e6893d60395d596 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 3 Jan 2025 16:29:38 +0000 Subject: [PATCH 5/8] Update lock file --- uv.lock | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/uv.lock b/uv.lock index 01d51250..45c72cc2 100644 --- a/uv.lock +++ b/uv.lock @@ -365,7 +365,7 @@ dependencies = [ { name = "autogen-ext", extra = ["azure", "openai"] }, { name = "grpcio" }, { name = "pyyaml" }, - { name = "text-2-sql-core", extra = ["databricks", "snowflake"] }, + { name = "text-2-sql-core", extra = ["databricks", "postgresql", "snowflake"] }, ] [package.dev-dependencies] @@ -388,7 +388,7 @@ requires-dist = [ { name = "autogen-ext", extras = ["azure", "openai"], specifier = "==0.4.0.dev11" }, { name = "grpcio", specifier = ">=1.68.1" }, { name = "pyyaml", specifier = ">=6.0.2" }, - { name = "text-2-sql-core", extras = ["snowflake", "databricks"], editable = "text_2_sql/text_2_sql_core" }, + { name = "text-2-sql-core", extras = ["snowflake", "databricks", "postgresql"], editable = "text_2_sql/text_2_sql_core" }, ] [package.metadata.requires-dev] @@ -2481,6 +2481,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/91/87fa6f060e649b1e1a7b19a4f5869709fbf750b7c8c262ee776ec32f3028/psutil-6.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:a8fb3752b491d246034fa4d279ff076501588ce8cbcdbb62c32fd7a377d996be", size = 254228 }, ] +[[package]] +name = "psycopg" +version = "3.2.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "tzdata", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d1/ad/7ce016ae63e231575df0498d2395d15f005f05e32d3a2d439038e1bd0851/psycopg-3.2.3.tar.gz", hash = "sha256:a5764f67c27bec8bfac85764d23c534af2c27b893550377e37ce59c12aac47a2", size = 155550 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/21/534b8f5bd9734b7a2fcd3a16b1ee82ef6cad81a4796e95ebf4e0c6a24119/psycopg-3.2.3-py3-none-any.whl", hash = "sha256:644d3973fe26908c73d4be746074f6e5224b03c1101d302d9a53bf565ad64907", size = 197934 }, +] + [[package]] name = "ptyprocess" version = "0.7.0" @@ -3376,6 +3389,9 @@ databricks = [ { name = "databricks-sql-connector" }, { name = "pyarrow" }, ] +postgresql = [ + { name = "psycopg" }, +] snowflake = [ { name = "snowflake-connector-python" }, ] @@ -3403,6 +3419,7 @@ requires-dist = [ { name = "numpy", specifier = "<2.0.0" }, { name = "openai", specifier = ">=1.55.3" }, { name = "pandas", specifier = ">=2.2.3" }, + { name = "psycopg", marker = "extra == 'postgresql'", specifier = ">=3.2.3" }, { name = "pyarrow", marker = "extra == 'databricks'", specifier = ">=14.0.2,<17" }, { name = "pydantic", specifier = ">=2.10.2" }, { name = "python-dotenv", specifier = ">=1.0.1" }, From 482a08d58117378243e45e1421ef707bdf491d86 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 3 Jan 2025 16:36:44 +0000 Subject: [PATCH 6/8] Update postgresql commands --- .../postgresql_data_dictionary_creator.py | 91 ++++++++++++------- 1 file changed, 57 insertions(+), 34 deletions(-) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/postgresql_data_dictionary_creator.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/postgresql_data_dictionary_creator.py index c020a8dc..e2327d16 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/postgresql_data_dictionary_creator.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/postgresql_data_dictionary_creator.py @@ -24,64 +24,87 @@ def __init__(self, **kwargs): def extract_table_entities_sql_query(self) -> str: """A property to extract table entities from a PostgreSQL database.""" return """SELECT - t.table_name AS entity, - t.table_schema AS entity_schema, - pg_catalog.col_description(c.oid, 0) AS definition + t.table_name AS "Entity", + t.table_schema AS "EntitySchema", + pg_catalog.obj_description(c.oid, 'pg_class') AS "Definition" FROM information_schema.tables t JOIN pg_catalog.pg_class c ON c.relname = t.table_name AND c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = t.table_schema) + LEFT JOIN + pg_catalog.pg_description pd ON pd.objoid = c.oid WHERE t.table_type = 'BASE TABLE' - ORDER BY entity_schema, entity;""" + AND pd.objsubid = 0 -- 0 indicates the table description, not column descriptions + ORDER BY + "EntitySchema", "Entity";""" @property def extract_view_entities_sql_query(self) -> str: """A property to extract view entities from a PostgreSQL database.""" return """SELECT - v.view_name AS entity, - v.table_schema AS entity_schema, - pg_catalog.col_description(c.oid, 0) AS definition + v.viewname AS "Entity", + v.schemaname AS "EntitySchema", + pg_catalog.obj_description(c.oid, 'pg_class') AS "Definition" FROM - information_schema.views v + pg_catalog.pg_views v JOIN - pg_catalog.pg_class c ON c.relname = v.view_name - AND c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = v.table_schema) - ORDER BY entity_schema, entity;""" + pg_catalog.pg_class c ON c.relname = v.viewname + AND c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = v.schemaname) + LEFT JOIN + pg_catalog.pg_description pd ON pd.objoid = c.oid + WHERE + pd.objsubid = 0 -- 0 indicates the view description, not a column description + ORDER BY + "EntitySchema", "Entity";""" def extract_columns_sql_query(self, entity: EntityItem) -> str: """A property to extract column information from a PostgreSQL database.""" return f"""SELECT - c.column_name AS name, - c.data_type AS data_type, - pg_catalog.col_description(t.oid, c.ordinal_position) AS definition + c.attname AS "Name", + t.typname AS "DataType", + pgd.description AS "Definition" FROM - information_schema.columns c - JOIN - pg_catalog.pg_class t ON t.relname = c.table_name - AND t.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema) + pg_attribute c + INNER JOIN + pg_class tbl ON c.attrelid = tbl.oid + INNER JOIN + pg_namespace ns ON tbl.relnamespace = ns.oid + INNER JOIN + pg_type t ON c.atttypid = t.oid + LEFT JOIN + pg_description pgd ON pgd.objoid = tbl.oid AND pgd.objsubid = c.attnum WHERE - c.table_schema = '{entity.entity_schema}' - AND c.table_name = '{entity.name}' - ORDER BY c.ordinal_position;""" + ns.nspname = '{entity.entity_schema}' + AND tbl.relname = '{entity.name}' + AND c.attnum > 0 -- Exclude system columns + ORDER BY + c.attnum;""" @property def extract_entity_relationships_sql_query(self) -> str: """A property to extract entity relationships from a PostgreSQL database.""" return """SELECT - tc.table_schema AS entity_schema, - tc.table_name AS entity, - rc.unique_constraint_schema AS foreign_entity_schema, - rc.unique_constraint_name AS foreign_entity_constraint, - rc.constraint_name AS foreign_key_constraint + fk_schema.nspname AS "EntitySchema", + fk_tab.relname AS "Entity", + pk_schema.nspname AS "ForeignEntitySchema", + pk_tab.relname AS "ForeignEntity", + fk_col.attname AS "Column", + pk_col.attname AS "ForeignColumn" FROM - information_schema.referential_constraints rc - JOIN - information_schema.table_constraints tc - ON rc.constraint_schema = tc.constraint_schema - AND rc.constraint_name = tc.constraint_name - WHERE - tc.constraint_type = 'FOREIGN KEY' + pg_constraint fk + INNER JOIN + pg_attribute fk_col ON fk.conrelid = fk_col.attrelid AND fk.attnum = fk_col.attnum + INNER JOIN + pg_class fk_tab ON fk.conrelid = fk_tab.oid + INNER JOIN + pg_namespace fk_schema ON fk_tab.relnamespace = fk_schema.oid + INNER JOIN + pg_class pk_tab ON fk.confrelid = pk_tab.oid + INNER JOIN + pg_namespace pk_schema ON pk_tab.relnamespace = pk_schema.oid + INNER JOIN + pg_attribute pk_col ON fk.confrelid = pk_col.attrelid AND fk.confkey[1] = pk_col.attnum ORDER BY - entity_schema, entity, foreign_entity_schema, foreign_entity_constraint;""" + "EntitySchema", "Entity", "ForeignEntitySchema", "ForeignEntity";""" From f885bd561296d5a89a6b32fe38af90057a97af13 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 3 Jan 2025 16:39:17 +0000 Subject: [PATCH 7/8] Make optional --- deploy_ai_search/pyproject.toml | 11 +++++++++++ text_2_sql/autogen/pyproject.toml | 13 ++++++++++++- uv.lock | 32 +++++++++++++++++++++++++++++-- 3 files changed, 53 insertions(+), 3 deletions(-) diff --git a/deploy_ai_search/pyproject.toml b/deploy_ai_search/pyproject.toml index fca3d6aa..1e98e816 100644 --- a/deploy_ai_search/pyproject.toml +++ b/deploy_ai_search/pyproject.toml @@ -26,3 +26,14 @@ dev = [ [tool.uv.sources] text_2_sql_core = { workspace = true } + +[project.optional-dependencies] +snowflake = [ + "text_2_sql_core[snowflake]", +] +databricks = [ + "text_2_sql_core[databricks]", +] +postgresql = [ + "text_2_sql_core[postgresql]", +] diff --git a/text_2_sql/autogen/pyproject.toml b/text_2_sql/autogen/pyproject.toml index 977d6e9e..68515028 100644 --- a/text_2_sql/autogen/pyproject.toml +++ b/text_2_sql/autogen/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "autogen-ext[azure,openai]==0.4.0.dev11", "grpcio>=1.68.1", "pyyaml>=6.0.2", - "text_2_sql_core[snowflake,databricks,postgresql]", + "text_2_sql_core", ] [dependency-groups] @@ -28,3 +28,14 @@ dev = [ [tool.uv.sources] text_2_sql_core = { workspace = true } + +[project.optional-dependencies] +snowflake = [ + "text_2_sql_core[snowflake]", +] +databricks = [ + "text_2_sql_core[databricks]", +] +postgresql = [ + "text_2_sql_core[postgresql]", +] diff --git a/uv.lock b/uv.lock index 45c72cc2..e88441fd 100644 --- a/uv.lock +++ b/uv.lock @@ -365,7 +365,18 @@ dependencies = [ { name = "autogen-ext", extra = ["azure", "openai"] }, { name = "grpcio" }, { name = "pyyaml" }, - { name = "text-2-sql-core", extra = ["databricks", "postgresql", "snowflake"] }, + { name = "text-2-sql-core" }, +] + +[package.optional-dependencies] +databricks = [ + { name = "text-2-sql-core", extra = ["databricks"] }, +] +postgresql = [ + { name = "text-2-sql-core", extra = ["postgresql"] }, +] +snowflake = [ + { name = "text-2-sql-core", extra = ["snowflake"] }, ] [package.dev-dependencies] @@ -388,7 +399,10 @@ requires-dist = [ { name = "autogen-ext", extras = ["azure", "openai"], specifier = "==0.4.0.dev11" }, { name = "grpcio", specifier = ">=1.68.1" }, { name = "pyyaml", specifier = ">=6.0.2" }, - { name = "text-2-sql-core", extras = ["snowflake", "databricks", "postgresql"], editable = "text_2_sql/text_2_sql_core" }, + { name = "text-2-sql-core", editable = "text_2_sql/text_2_sql_core" }, + { name = "text-2-sql-core", extras = ["databricks"], marker = "extra == 'databricks'", editable = "text_2_sql/text_2_sql_core" }, + { name = "text-2-sql-core", extras = ["postgresql"], marker = "extra == 'postgresql'", editable = "text_2_sql/text_2_sql_core" }, + { name = "text-2-sql-core", extras = ["snowflake"], marker = "extra == 'snowflake'", editable = "text_2_sql/text_2_sql_core" }, ] [package.metadata.requires-dev] @@ -929,6 +943,17 @@ dependencies = [ { name = "text-2-sql-core" }, ] +[package.optional-dependencies] +databricks = [ + { name = "text-2-sql-core", extra = ["databricks"] }, +] +postgresql = [ + { name = "text-2-sql-core", extra = ["postgresql"] }, +] +snowflake = [ + { name = "text-2-sql-core", extra = ["snowflake"] }, +] + [package.dev-dependencies] dev = [ { name = "black" }, @@ -948,6 +973,9 @@ requires-dist = [ { name = "azure-storage-blob", specifier = ">=12.24.0" }, { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "text-2-sql-core", editable = "text_2_sql/text_2_sql_core" }, + { name = "text-2-sql-core", extras = ["databricks"], marker = "extra == 'databricks'", editable = "text_2_sql/text_2_sql_core" }, + { name = "text-2-sql-core", extras = ["postgresql"], marker = "extra == 'postgresql'", editable = "text_2_sql/text_2_sql_core" }, + { name = "text-2-sql-core", extras = ["snowflake"], marker = "extra == 'snowflake'", editable = "text_2_sql/text_2_sql_core" }, ] [package.metadata.requires-dev] From 3ac3ce1832da7613d3bb42e15c702e1cf4753a13 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 6 Jan 2025 12:29:11 +0000 Subject: [PATCH 8/8] Update postgresql --- .../postgresql_data_dictionary_creator.py | 40 ++++++++++--------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/postgresql_data_dictionary_creator.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/postgresql_data_dictionary_creator.py index e2327d16..d5116594 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/postgresql_data_dictionary_creator.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/postgresql_data_dictionary_creator.py @@ -7,18 +7,19 @@ import os from text_2_sql_core.utils.database import DatabaseEngine -from text_2_sql_core.connectors.postgresql_sql import PostgresSqlConnector +from text_2_sql_core.connectors.postgresql_sql import PostgresqlSqlConnector class PostgresqlDataDictionaryCreator(DataDictionaryCreator): def __init__(self, **kwargs): """A method to initialize the DataDictionaryCreator class.""" - super().__init__(**kwargs) + excluded_schemas = ["information_schema", "pg_catalog"] + super().__init__(excluded_schemas=excluded_schemas, **kwargs) self.database = os.environ["Text2Sql__DatabaseName"] self.database_engine = DatabaseEngine.POSTGRESQL - self.sql_connector = PostgresSqlConnector() + self.sql_connector = PostgresqlSqlConnector() @property def extract_table_entities_sql_query(self) -> str: @@ -29,14 +30,16 @@ def extract_table_entities_sql_query(self) -> str: pg_catalog.obj_description(c.oid, 'pg_class') AS "Definition" FROM information_schema.tables t - JOIN - pg_catalog.pg_class c ON c.relname = t.table_name - AND c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = t.table_schema) LEFT JOIN - pg_catalog.pg_description pd ON pd.objoid = c.oid + pg_catalog.pg_class c + ON c.relname = t.table_name + AND c.relnamespace = ( + SELECT oid + FROM pg_catalog.pg_namespace + WHERE nspname = t.table_schema + ) WHERE t.table_type = 'BASE TABLE' - AND pd.objsubid = 0 -- 0 indicates the table description, not column descriptions ORDER BY "EntitySchema", "Entity";""" @@ -44,18 +47,19 @@ def extract_table_entities_sql_query(self) -> str: def extract_view_entities_sql_query(self) -> str: """A property to extract view entities from a PostgreSQL database.""" return """SELECT - v.viewname AS "Entity", - v.schemaname AS "EntitySchema", + v.table_name AS "Entity", + v.table_schema AS "EntitySchema", pg_catalog.obj_description(c.oid, 'pg_class') AS "Definition" FROM - pg_catalog.pg_views v - JOIN - pg_catalog.pg_class c ON c.relname = v.viewname - AND c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = v.schemaname) + information_schema.views v LEFT JOIN - pg_catalog.pg_description pd ON pd.objoid = c.oid - WHERE - pd.objsubid = 0 -- 0 indicates the view description, not a column description + pg_catalog.pg_class c + ON c.relname = v.table_name + AND c.relnamespace = ( + SELECT oid + FROM pg_catalog.pg_namespace + WHERE nspname = v.table_schema + ) ORDER BY "EntitySchema", "Entity";""" @@ -95,7 +99,7 @@ def extract_entity_relationships_sql_query(self) -> str: FROM pg_constraint fk INNER JOIN - pg_attribute fk_col ON fk.conrelid = fk_col.attrelid AND fk.attnum = fk_col.attnum + pg_attribute fk_col ON fk.conrelid = fk_col.attrelid AND fk.conkey[1] = fk_col.attnum INNER JOIN pg_class fk_tab ON fk.conrelid = fk_tab.oid INNER JOIN