From 2522c1e92bd8e8e9bb2555e8ae1dfe5e69b4cc51 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 16 Jan 2025 17:33:56 +0000 Subject: [PATCH 1/3] Update naming --- text_2_sql/autogen/README.md | 2 +- .../parallel_query_solving_agent.py | 10 ++++- .../semantic_kernel/README.md | 4 +- .../vector_based_sql_plugin.py | 24 ++++++----- .../src/text_2_sql_core/connectors/sql.py | 40 +++++++++---------- .../custom_agents/sql_query_cache_agent.py | 25 +++++++----- 6 files changed, 61 insertions(+), 44 deletions(-) diff --git a/text_2_sql/autogen/README.md b/text_2_sql/autogen/README.md index 68e2da90..a3eb1dae 100644 --- a/text_2_sql/autogen/README.md +++ b/text_2_sql/autogen/README.md @@ -134,7 +134,7 @@ Each agent can be configured with specific parameters and prompts to optimize it ## Query Cache Implementation Details -The vector based with query cache uses the `fetch_queries_from_cache()` method to fetch the most relevant previous query and injects it into the prompt before the initial LLM call. The use of Auto-Function Calling here is avoided to reduce the response time as the cache index will always be used first. +The vector based with query cache uses the `fetch_sql_queries_with_schemas_from_cache()` method to fetch the most relevant previous query and injects it into the prompt before the initial LLM call. The use of Auto-Function Calling here is avoided to reduce the response time as the cache index will always be used first. If the score of the top result is higher than the defined threshold, the query will be executed against the target data source and the results included in the prompt. This allows us to prompt the LLM to evaluated whether it can use these results to answer the question, **without further SQL Query generation** to speed up the process. diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index c93db3e4..2fe79916 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -152,8 +152,14 @@ async def consume_inner_messages_from_agentic_flow( # Search for specific message types and add them to the final output object if isinstance(parsed_message, dict): # Check if the message contains pre-run results - if ("contains_pre_run_results" in parsed_message) and ( - parsed_message["contains_pre_run_results"] is True + if ( + "contains_cached_sql_queries_with_schemas_from_cache_database_results" + in parsed_message + ) and ( + parsed_message[ + "contains_cached_sql_queries_with_schemas_from_cache_database_results" + ] + is True ): logging.info("Contains pre-run results") for pre_run_sql_query, pre_run_result in parsed_message[ diff --git a/text_2_sql/previous_iterations/semantic_kernel/README.md b/text_2_sql/previous_iterations/semantic_kernel/README.md index c23396aa..8041c314 100644 --- a/text_2_sql/previous_iterations/semantic_kernel/README.md +++ b/text_2_sql/previous_iterations/semantic_kernel/README.md @@ -134,9 +134,9 @@ This method is called by the Semantic Kernel framework automatically, when instr The search text passed is vectorised against the entity level **Description** columns. A hybrid Semantic Reranking search is applied against the **EntityName**, **Entity**, **Columns/Name** fields. -#### fetch_queries_from_cache() +#### fetch_sql_queries_with_schemas_from_cache() -The vector based with query cache uses the `fetch_queries_from_cache()` method to fetch the most relevant previous query and injects it into the prompt before the initial LLM call. The use of Auto-Function Calling here is avoided to reduce the response time as the cache index will always be used first. +The vector based with query cache uses the `fetch_sql_queries_with_schemas_from_cache()` method to fetch the most relevant previous query and injects it into the prompt before the initial LLM call. The use of Auto-Function Calling here is avoided to reduce the response time as the cache index will always be used first. If the score of the top result is higher than the defined threshold, the query will be executed against the target data source and the results included in the prompt. This allows us to prompt the LLM to evaluated whether it can use these results to answer the question, **without further SQL Query generation** to speed up the process. diff --git a/text_2_sql/previous_iterations/semantic_kernel/plugins/vector_based_sql_plugin/vector_based_sql_plugin.py b/text_2_sql/previous_iterations/semantic_kernel/plugins/vector_based_sql_plugin/vector_based_sql_plugin.py index 8b1720ac..6b3c2cd7 100644 --- a/text_2_sql/previous_iterations/semantic_kernel/plugins/vector_based_sql_plugin/vector_based_sql_plugin.py +++ b/text_2_sql/previous_iterations/semantic_kernel/plugins/vector_based_sql_plugin/vector_based_sql_plugin.py @@ -137,7 +137,7 @@ async def fetch_schemas_from_store(self, search: str) -> list[dict]: return schemas - async def fetch_queries_from_cache(self, question: str) -> str: + async def fetch_sql_queries_with_schemas_from_cache(self, question: str) -> str: """Fetch the queries from the cache based on the question. Args: @@ -151,7 +151,7 @@ async def fetch_queries_from_cache(self, question: str) -> str: if not self.use_query_cache: return None - cached_schemas = await self.ai_search.run_ai_search_query( + sql_queries_with_schemas = await self.ai_search.run_ai_search_query( question, ["QuestionEmbedding"], ["Question", "SqlQueryDecomposition"], @@ -164,11 +164,11 @@ async def fetch_queries_from_cache(self, question: str) -> str: minimum_score=1.5, ) - if len(cached_schemas) == 0: + if len(sql_queries_with_schemas) == 0: return None else: database = os.environ["Text2Sql__DatabaseName"] - for entry in cached_schemas["SqlQueryDecomposition"]: + for entry in sql_queries_with_schemas["SqlQueryDecomposition"]: for schema in entry["Schemas"]: entity = schema["Entity"] schema["SelectFromEntity"] = f"{database}.{entity}" @@ -176,14 +176,16 @@ async def fetch_queries_from_cache(self, question: str) -> str: self.schemas[entity] = schema pre_fetched_results_string = "" - if self.pre_run_query_cache and len(cached_schemas) > 0: - logging.info("Cached schemas: %s", cached_schemas) + if self.pre_run_query_cache and len(sql_queries_with_schemas) > 0: + logging.info( + "Cached SQL Queries with Schemas: %s", sql_queries_with_schemas + ) # check the score - if cached_schemas[0]["@search.reranker_score"] > 2.75: + if sql_queries_with_schemas[0]["@search.reranker_score"] > 2.75: logging.info("Score is greater than 3") - sql_queries = cached_schemas[0]["SqlQueryDecomposition"] + sql_queries = sql_queries_with_schemas[0]["SqlQueryDecomposition"] query_result_store = {} query_tasks = [] @@ -208,7 +210,7 @@ async def fetch_queries_from_cache(self, question: str) -> str: return pre_fetched_results_string formatted_sql_cache_string = f"""[BEGIN CACHED QUERIES AND SCHEMAS]:\n{ - json.dumps(cached_schemas, default=str)}[END CACHED QUERIES AND SCHEMAS]""" + json.dumps(sql_queries_with_schemas, default=str)}[END CACHED QUERIES AND SCHEMAS]""" return formatted_sql_cache_string @@ -230,7 +232,9 @@ async def sql_prompt_injection( self.set_mode() if self.use_query_cache: - query_cache_string = await self.fetch_queries_from_cache(question) + query_cache_string = await self.fetch_sql_queries_with_schemas_from_cache( + question + ) else: query_cache_string = None diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 46feb4d0..45762311 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -260,7 +260,7 @@ def handle_node(node): logging.info("SQL Query is valid.") return True - async def fetch_queries_from_cache( + async def fetch_sql_queries_with_schemas_from_cache( self, question: str, injected_parameters: dict = None ) -> str: """Fetch the queries from the cache based on the question. @@ -276,14 +276,14 @@ async def fetch_queries_from_cache( # Return empty results if AI Search is disabled if not self.use_ai_search: return { - "contains_pre_run_results": False, - "cached_questions_and_schemas": None, + "contains_cached_sql_queries_with_schemas_from_cache_database_results": False, + "cached_sql_queries_with_schemas_from_cache": None, } if injected_parameters is None: injected_parameters = {} - cached_schemas = await self.ai_search_connector.run_ai_search_query( + sql_queries_with_schemas = await self.ai_search_connector.run_ai_search_query( question, ["QuestionEmbedding"], ["Question", "SqlQueryDecomposition"], @@ -294,32 +294,30 @@ async def fetch_queries_from_cache( minimum_score=1.5, ) - if len(cached_schemas) == 0: + if len(sql_queries_with_schemas) == 0: return { - "contains_pre_run_results": False, - "cached_questions_and_schemas": None, + "contains_cached_sql_queries_with_schemas_from_cache_database_results": False, + "cached_sql_queries_with_schemas_from_cache": None, } # loop through all sql queries and populate the template in place - for schema in cached_schemas: - sql_queries = schema["SqlQueryDecomposition"] - for sql_query in sql_queries: + for queries_with_schemas in sql_queries_with_schemas: + for sql_query in queries_with_schemas["SqlQueryDecomposition"]: sql_query["SqlQuery"] = Template(sql_query["SqlQuery"]).render( **injected_parameters ) - logging.info("Cached schemas: %s", cached_schemas) - if self.pre_run_query_cache and len(cached_schemas) > 0: + logging.info("Cached SQL Queries with Schemas: %s", sql_queries_with_schemas) + if self.pre_run_query_cache and len(sql_queries_with_schemas) > 0: # check the score - if cached_schemas[0]["@search.reranker_score"] > 2.75: + if sql_queries_with_schemas[0]["@search.reranker_score"] > 2.75: logging.info("Score is greater than 3") - sql_queries = cached_schemas[0]["SqlQueryDecomposition"] query_result_store = {} query_tasks = [] - for sql_query in sql_queries: + for sql_query in sql_queries_with_schemas[0]["SqlQueryDecomposition"]: logging.info("SQL Query: %s", sql_query) # Run the SQL query @@ -327,18 +325,20 @@ async def fetch_queries_from_cache( sql_results = await asyncio.gather(*query_tasks) - for sql_query, sql_result in zip(sql_queries, sql_results): + for sql_query, sql_result in zip( + sql_queries_with_schemas[0]["SqlQueryDecomposition"], sql_results + ): query_result_store[sql_query["SqlQuery"]] = { "sql_rows": sql_result, "schemas": sql_query["Schemas"], } return { - "contains_pre_run_results": True, - "cached_questions_and_schemas": query_result_store, + "contains_cached_sql_queries_with_schemas_from_cache_database_results": True, + "cached_sql_queries_with_schemas_from_cache": query_result_store, } return { - "contains_pre_run_results": False, - "cached_questions_and_schemas": cached_schemas, + "contains_cached_sql_queries_with_schemas_from_cache_database_results": False, + "cached_sql_queries_with_schemas_from_cache": sql_queries_with_schemas, } diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py index b58f9b5a..d1a236e8 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py @@ -13,26 +13,33 @@ async def process_message( ) -> dict: # Initialize results dictionary cached_results = { - "cached_questions_and_schemas": [], - "contains_pre_run_results": False, + "cached_sql_queries_with_schemas_from_cache": [], + "contains_cached_sql_queries_with_schemas_from_cache_database_results": False, } # Process each question sequentially for message in messages: # Fetch the queries from the cache based on the question logging.info(f"Fetching queries from cache for question: {message}") - cached_query = await self.sql_connector.fetch_queries_from_cache( - message, injected_parameters=injected_parameters + cached_query = ( + await self.sql_connector.fetch_sql_queries_with_schemas_from_cache( + message, injected_parameters=injected_parameters + ) ) # If any question has pre-run results, set the flag - if cached_query.get("contains_pre_run_results", False): - cached_results["contains_pre_run_results"] = True + if cached_query.get( + "contains_cached_sql_queries_with_schemas_from_cache_database_results", + False, + ): + cached_results[ + "contains_cached_sql_queries_with_schemas_from_cache_database_results" + ] = True # Add the cached results for this question - if cached_query.get("cached_questions_and_schemas"): - cached_results["cached_questions_and_schemas"].extend( - cached_query["cached_questions_and_schemas"] + if cached_query.get("cached_sql_queries_with_schemas_from_cache"): + cached_results["cached_sql_queries_with_schemas_from_cache"].extend( + cached_query["cached_sql_queries_with_schemas_from_cache"] ) logging.info(f"Final cached results: {cached_results}") From dac693668463c954726686def320b2a04bea904a Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 16 Jan 2025 17:37:57 +0000 Subject: [PATCH 2/3] Fix poor naming in connectors --- .../src/text_2_sql_core/connectors/factory.py | 4 ++-- .../src/text_2_sql_core/connectors/tsql_sql.py | 2 +- .../src/text_2_sql_core/data_dictionary/cli.py | 16 ++++++++++++++-- .../tsql_data_dictionary_creator.py | 8 ++++---- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/factory.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/factory.py index 21f9941f..a409f5b7 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/factory.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/factory.py @@ -22,9 +22,9 @@ def get_database_connector(): return SnowflakeSqlConnector() elif os.environ["Text2Sql__DatabaseEngine"].upper() == "TSQL": - from text_2_sql_core.connectors.tsql_sql import TSQLSqlConnector + from text_2_sql_core.connectors.tsql_sql import TsqlSqlConnector - return TSQLSqlConnector() + return TsqlSqlConnector() elif os.environ["Text2Sql__DatabaseEngine"].upper() == "POSTGRESQL": from text_2_sql_core.connectors.postgresql_sql import ( PostgresqlSqlConnector, diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py index 610692cb..3cf6bcd5 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py @@ -10,7 +10,7 @@ from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields -class TSQLSqlConnector(SqlConnector): +class TsqlSqlConnector(SqlConnector): def __init__(self): super().__init__() diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/cli.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/cli.py index 67897a1a..9355bcf1 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/cli.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/cli.py @@ -83,12 +83,24 @@ def create( ) elif engine == DatabaseEngine.TSQL: from text_2_sql_core.data_dictionary.tsql_data_dictionary_creator import ( - TSQLDataDictionaryCreator, + TsqlDataDictionaryCreator, ) - data_dictionary_creator = TSQLDataDictionaryCreator( + data_dictionary_creator = TsqlDataDictionaryCreator( **kwargs, ) + elif engine == DatabaseEngine.POSTGRESQL: + from text_2_sql_core.data_dictionary.postgresql_data_dictionary_creator import ( + PostgresqlDataDictionaryCreator, + ) + + data_dictionary_creator = PostgresqlDataDictionaryCreator( + **kwargs, + ) + else: + raise NotImplementedError( + f"Data Dictionary Creator for {engine.value} is not implemented." + ) except ImportError: detailed_error = f"""Failed to import { engine.value} Data Dictionary Creator. Check you have installed the optional dependencies for this database engine.""" diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/tsql_data_dictionary_creator.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/tsql_data_dictionary_creator.py index 956cba4e..2d360851 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/tsql_data_dictionary_creator.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/tsql_data_dictionary_creator.py @@ -7,10 +7,10 @@ import asyncio import os from text_2_sql_core.utils.database import DatabaseEngine -from text_2_sql_core.connectors.tsql_sql import TSQLSqlConnector +from text_2_sql_core.connectors.tsql_sql import TsqlSqlConnector -class TSQLDataDictionaryCreator(DataDictionaryCreator): +class TsqlDataDictionaryCreator(DataDictionaryCreator): def __init__(self, **kwargs): """A method to initialize the DataDictionaryCreator class. @@ -25,7 +25,7 @@ def __init__(self, **kwargs): self.database_engine = DatabaseEngine.TSQL - self.sql_connector = TSQLSqlConnector() + self.sql_connector = TsqlSqlConnector() """A class to extract data dictionary information from a SQL Server database.""" @@ -115,5 +115,5 @@ def extract_entity_relationships_sql_query(self) -> str: if __name__ == "__main__": - data_dictionary_creator = TSQLDataDictionaryCreator() + data_dictionary_creator = TsqlDataDictionaryCreator() asyncio.run(data_dictionary_creator.create_data_dictionary()) From e1dbc13b11194ccf40eaf821ed312fdafc93c65d Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 16 Jan 2025 17:38:48 +0000 Subject: [PATCH 3/3] Update cli --- .../src/text_2_sql_core/data_dictionary/cli.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/cli.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/cli.py index 9355bcf1..586f4660 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/cli.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/cli.py @@ -98,9 +98,10 @@ def create( **kwargs, ) else: - raise NotImplementedError( - f"Data Dictionary Creator for {engine.value} is not implemented." - ) + rich_print("Text2SQL Data Dictionary Creator Failed ❌") + rich_print(f"Database Engine {engine.value} is not supported.") + + raise typer.Exit(code=1) except ImportError: detailed_error = f"""Failed to import { engine.value} Data Dictionary Creator. Check you have installed the optional dependencies for this database engine.""" @@ -112,7 +113,6 @@ def create( try: asyncio.run(data_dictionary_creator.create_data_dictionary()) except Exception as e: - raise e logging.error(e) rich_print("Text2SQL Data Dictionary Creator Failed ❌")