Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion text_2_sql/autogen/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ Each agent can be configured with specific parameters and prompts to optimize it

## Query Cache Implementation Details

The vector based with query cache uses the `fetch_queries_from_cache()` method to fetch the most relevant previous query and injects it into the prompt before the initial LLM call. The use of Auto-Function Calling here is avoided to reduce the response time as the cache index will always be used first.
The vector-based approach with query cache uses the `fetch_sql_queries_with_schemas_from_cache()` method to fetch the most relevant previous query and injects it into the prompt before the initial LLM call. The use of Auto-Function Calling here is avoided to reduce the response time as the cache index will always be used first.

If the score of the top result is higher than the defined threshold, the query will be executed against the target data source and the results included in the prompt. This allows us to prompt the LLM to evaluate whether it can use these results to answer the question, **without further SQL Query generation** to speed up the process.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,14 @@ async def consume_inner_messages_from_agentic_flow(
# Search for specific message types and add them to the final output object
if isinstance(parsed_message, dict):
# Check if the message contains pre-run results
if ("contains_pre_run_results" in parsed_message) and (
parsed_message["contains_pre_run_results"] is True
if (
"contains_cached_sql_queries_with_schemas_from_cache_database_results"
in parsed_message
) and (
parsed_message[
"contains_cached_sql_queries_with_schemas_from_cache_database_results"
]
is True
):
logging.info("Contains pre-run results")
for pre_run_sql_query, pre_run_result in parsed_message[
Expand Down
4 changes: 2 additions & 2 deletions text_2_sql/previous_iterations/semantic_kernel/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ This method is called by the Semantic Kernel framework automatically, when instr

The search text passed is vectorised against the entity level **Description** columns. A hybrid Semantic Reranking search is applied against the **EntityName**, **Entity**, **Columns/Name** fields.

#### fetch_queries_from_cache()
#### fetch_sql_queries_with_schemas_from_cache()

The vector based with query cache uses the `fetch_queries_from_cache()` method to fetch the most relevant previous query and injects it into the prompt before the initial LLM call. The use of Auto-Function Calling here is avoided to reduce the response time as the cache index will always be used first.
The vector-based approach with query cache uses the `fetch_sql_queries_with_schemas_from_cache()` method to fetch the most relevant previous query and injects it into the prompt before the initial LLM call. The use of Auto-Function Calling here is avoided to reduce the response time as the cache index will always be used first.

If the score of the top result is higher than the defined threshold, the query will be executed against the target data source and the results included in the prompt. This allows us to prompt the LLM to evaluate whether it can use these results to answer the question, **without further SQL Query generation** to speed up the process.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def fetch_schemas_from_store(self, search: str) -> list[dict]:

return schemas

async def fetch_queries_from_cache(self, question: str) -> str:
async def fetch_sql_queries_with_schemas_from_cache(self, question: str) -> str:
"""Fetch the queries from the cache based on the question.

Args:
Expand All @@ -151,7 +151,7 @@ async def fetch_queries_from_cache(self, question: str) -> str:
if not self.use_query_cache:
return None

cached_schemas = await self.ai_search.run_ai_search_query(
sql_queries_with_schemas = await self.ai_search.run_ai_search_query(
question,
["QuestionEmbedding"],
["Question", "SqlQueryDecomposition"],
Expand All @@ -164,26 +164,28 @@ async def fetch_queries_from_cache(self, question: str) -> str:
minimum_score=1.5,
)

if len(cached_schemas) == 0:
if len(sql_queries_with_schemas) == 0:
return None
else:
database = os.environ["Text2Sql__DatabaseName"]
for entry in cached_schemas["SqlQueryDecomposition"]:
for entry in sql_queries_with_schemas["SqlQueryDecomposition"]:
for schema in entry["Schemas"]:
entity = schema["Entity"]
schema["SelectFromEntity"] = f"{database}.{entity}"

self.schemas[entity] = schema

pre_fetched_results_string = ""
if self.pre_run_query_cache and len(cached_schemas) > 0:
logging.info("Cached schemas: %s", cached_schemas)
if self.pre_run_query_cache and len(sql_queries_with_schemas) > 0:
logging.info(
"Cached SQL Queries with Schemas: %s", sql_queries_with_schemas
)

# check the score
if cached_schemas[0]["@search.reranker_score"] > 2.75:
if sql_queries_with_schemas[0]["@search.reranker_score"] > 2.75:
logging.info("Score is greater than 3")

sql_queries = cached_schemas[0]["SqlQueryDecomposition"]
sql_queries = sql_queries_with_schemas[0]["SqlQueryDecomposition"]
query_result_store = {}

query_tasks = []
Expand All @@ -208,7 +210,7 @@ async def fetch_queries_from_cache(self, question: str) -> str:
return pre_fetched_results_string

formatted_sql_cache_string = f"""[BEGIN CACHED QUERIES AND SCHEMAS]:\n{
json.dumps(cached_schemas, default=str)}[END CACHED QUERIES AND SCHEMAS]"""
json.dumps(sql_queries_with_schemas, default=str)}[END CACHED QUERIES AND SCHEMAS]"""

return formatted_sql_cache_string

Expand All @@ -230,7 +232,9 @@ async def sql_prompt_injection(
self.set_mode()

if self.use_query_cache:
query_cache_string = await self.fetch_queries_from_cache(question)
query_cache_string = await self.fetch_sql_queries_with_schemas_from_cache(
question
)
else:
query_cache_string = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def get_database_connector():

return SnowflakeSqlConnector()
elif os.environ["Text2Sql__DatabaseEngine"].upper() == "TSQL":
from text_2_sql_core.connectors.tsql_sql import TSQLSqlConnector
from text_2_sql_core.connectors.tsql_sql import TsqlSqlConnector

return TSQLSqlConnector()
return TsqlSqlConnector()
elif os.environ["Text2Sql__DatabaseEngine"].upper() == "POSTGRESQL":
from text_2_sql_core.connectors.postgresql_sql import (
PostgresqlSqlConnector,
Expand Down
40 changes: 20 additions & 20 deletions text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def handle_node(node):
logging.info("SQL Query is valid.")
return True

async def fetch_queries_from_cache(
async def fetch_sql_queries_with_schemas_from_cache(
self, question: str, injected_parameters: dict = None
) -> str:
"""Fetch the queries from the cache based on the question.
Expand All @@ -276,14 +276,14 @@ async def fetch_queries_from_cache(
# Return empty results if AI Search is disabled
if not self.use_ai_search:
return {
"contains_pre_run_results": False,
"cached_questions_and_schemas": None,
"contains_cached_sql_queries_with_schemas_from_cache_database_results": False,
"cached_sql_queries_with_schemas_from_cache": None,
}

if injected_parameters is None:
injected_parameters = {}

cached_schemas = await self.ai_search_connector.run_ai_search_query(
sql_queries_with_schemas = await self.ai_search_connector.run_ai_search_query(
question,
["QuestionEmbedding"],
["Question", "SqlQueryDecomposition"],
Expand All @@ -294,51 +294,51 @@ async def fetch_queries_from_cache(
minimum_score=1.5,
)

if len(cached_schemas) == 0:
if len(sql_queries_with_schemas) == 0:
return {
"contains_pre_run_results": False,
"cached_questions_and_schemas": None,
"contains_cached_sql_queries_with_schemas_from_cache_database_results": False,
"cached_sql_queries_with_schemas_from_cache": None,
}

# loop through all sql queries and populate the template in place
for schema in cached_schemas:
sql_queries = schema["SqlQueryDecomposition"]
for sql_query in sql_queries:
for queries_with_schemas in sql_queries_with_schemas:
for sql_query in queries_with_schemas["SqlQueryDecomposition"]:
sql_query["SqlQuery"] = Template(sql_query["SqlQuery"]).render(
**injected_parameters
)

logging.info("Cached schemas: %s", cached_schemas)
if self.pre_run_query_cache and len(cached_schemas) > 0:
logging.info("Cached SQL Queries with Schemas: %s", sql_queries_with_schemas)
if self.pre_run_query_cache and len(sql_queries_with_schemas) > 0:
# check the score
if cached_schemas[0]["@search.reranker_score"] > 2.75:
if sql_queries_with_schemas[0]["@search.reranker_score"] > 2.75:
logging.info("Score is greater than 3")

sql_queries = cached_schemas[0]["SqlQueryDecomposition"]
query_result_store = {}

query_tasks = []

for sql_query in sql_queries:
for sql_query in sql_queries_with_schemas[0]["SqlQueryDecomposition"]:
logging.info("SQL Query: %s", sql_query)

# Run the SQL query
query_tasks.append(self.query_execution(sql_query["SqlQuery"]))

sql_results = await asyncio.gather(*query_tasks)

for sql_query, sql_result in zip(sql_queries, sql_results):
for sql_query, sql_result in zip(
sql_queries_with_schemas[0]["SqlQueryDecomposition"], sql_results
):
query_result_store[sql_query["SqlQuery"]] = {
"sql_rows": sql_result,
"schemas": sql_query["Schemas"],
}

return {
"contains_pre_run_results": True,
"cached_questions_and_schemas": query_result_store,
"contains_cached_sql_queries_with_schemas_from_cache_database_results": True,
"cached_sql_queries_with_schemas_from_cache": query_result_store,
}

return {
"contains_pre_run_results": False,
"cached_questions_and_schemas": cached_schemas,
"contains_cached_sql_queries_with_schemas_from_cache_database_results": False,
"cached_sql_queries_with_schemas_from_cache": sql_queries_with_schemas,
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields


class TSQLSqlConnector(SqlConnector):
class TsqlSqlConnector(SqlConnector):
def __init__(self):
super().__init__()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,33 @@ async def process_message(
) -> dict:
# Initialize results dictionary
cached_results = {
"cached_questions_and_schemas": [],
"contains_pre_run_results": False,
"cached_sql_queries_with_schemas_from_cache": [],
"contains_cached_sql_queries_with_schemas_from_cache_database_results": False,
}

# Process each question sequentially
for message in messages:
# Fetch the queries from the cache based on the question
logging.info(f"Fetching queries from cache for question: {message}")
cached_query = await self.sql_connector.fetch_queries_from_cache(
message, injected_parameters=injected_parameters
cached_query = (
await self.sql_connector.fetch_sql_queries_with_schemas_from_cache(
message, injected_parameters=injected_parameters
)
)

# If any question has pre-run results, set the flag
if cached_query.get("contains_pre_run_results", False):
cached_results["contains_pre_run_results"] = True
if cached_query.get(
"contains_cached_sql_queries_with_schemas_from_cache_database_results",
False,
):
cached_results[
"contains_cached_sql_queries_with_schemas_from_cache_database_results"
] = True

# Add the cached results for this question
if cached_query.get("cached_questions_and_schemas"):
cached_results["cached_questions_and_schemas"].extend(
cached_query["cached_questions_and_schemas"]
if cached_query.get("cached_sql_queries_with_schemas_from_cache"):
cached_results["cached_sql_queries_with_schemas_from_cache"].extend(
cached_query["cached_sql_queries_with_schemas_from_cache"]
)

logging.info(f"Final cached results: {cached_results}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,25 @@ def create(
)
elif engine == DatabaseEngine.TSQL:
from text_2_sql_core.data_dictionary.tsql_data_dictionary_creator import (
TSQLDataDictionaryCreator,
TsqlDataDictionaryCreator,
)

data_dictionary_creator = TSQLDataDictionaryCreator(
data_dictionary_creator = TsqlDataDictionaryCreator(
**kwargs,
)
elif engine == DatabaseEngine.POSTGRESQL:
from text_2_sql_core.data_dictionary.postgresql_data_dictionary_creator import (
PostgresqlDataDictionaryCreator,
)

data_dictionary_creator = PostgresqlDataDictionaryCreator(
**kwargs,
)
else:
rich_print("Text2SQL Data Dictionary Creator Failed ❌")
rich_print(f"Database Engine {engine.value} is not supported.")

raise typer.Exit(code=1)
except ImportError:
detailed_error = f"""Failed to import {
engine.value} Data Dictionary Creator. Check you have installed the optional dependencies for this database engine."""
Expand All @@ -100,7 +113,6 @@ def create(
try:
asyncio.run(data_dictionary_creator.create_data_dictionary())
except Exception as e:
raise e
logging.error(e)
rich_print("Text2SQL Data Dictionary Creator Failed ❌")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import asyncio
import os
from text_2_sql_core.utils.database import DatabaseEngine
from text_2_sql_core.connectors.tsql_sql import TSQLSqlConnector
from text_2_sql_core.connectors.tsql_sql import TsqlSqlConnector


class TSQLDataDictionaryCreator(DataDictionaryCreator):
class TsqlDataDictionaryCreator(DataDictionaryCreator):
def __init__(self, **kwargs):
"""A method to initialize the DataDictionaryCreator class.

Expand All @@ -25,7 +25,7 @@ def __init__(self, **kwargs):

self.database_engine = DatabaseEngine.TSQL

self.sql_connector = TSQLSqlConnector()
self.sql_connector = TsqlSqlConnector()

"""A class to extract data dictionary information from a SQL Server database."""

Expand Down Expand Up @@ -115,5 +115,5 @@ def extract_entity_relationships_sql_query(self) -> str:


if __name__ == "__main__":
data_dictionary_creator = TSQLDataDictionaryCreator()
data_dictionary_creator = TsqlDataDictionaryCreator()
asyncio.run(data_dictionary_creator.create_data_dictionary())
Loading