From 4884a7d1d46a55994756321e04123266198454f0 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Tue, 7 Jan 2025 18:42:08 +0000 Subject: [PATCH 1/3] Add engine specific rules --- .../src/text_2_sql_core/connectors/databricks_sql.py | 5 +++++ .../src/text_2_sql_core/connectors/snowflake_sql.py | 5 +++++ .../text_2_sql_core/src/text_2_sql_core/connectors/sql.py | 6 ++++++ .../src/text_2_sql_core/connectors/tsql_sql.py | 5 +++++ 4 files changed, 21 insertions(+) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py index c72322c..04b7327 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py @@ -17,6 +17,11 @@ def __init__(self): self.database_engine = DatabaseEngine.DATABRICKS + @property + def engine_specific_rules(self) -> str: + """Get the engine specific rules.""" + return + @property def engine_specific_fields(self) -> list[str]: """Get the engine specific fields.""" diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py index e71cd01..5f40627 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py @@ -17,6 +17,11 @@ def __init__(self): self.database_engine = DatabaseEngine.SNOWFLAKE + @property + def engine_specific_rules(self) -> str: + """Get the engine specific rules.""" + return """When an ORDER BY clause is included in the SQL query, always append the ORDER BY clause with 'NULLS LAST' to ensure that NULL values are at the end of the result set. e.g. 
'ORDER BY column_name DESC NULLS LAST'.""" + @property def engine_specific_fields(self) -> list[str]: """Get the engine specific fields.""" diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 87cd30a..69c9801 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -31,6 +31,12 @@ def __init__(self): self.database_engine = None + @property + @abstractmethod + def engine_specific_rules(self) -> str: + """Get the engine specific rules.""" + pass + @property @abstractmethod def invalid_identifiers(self) -> list[str]: diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py index 36c1f83..610692c 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py @@ -16,6 +16,11 @@ def __init__(self): self.database_engine = DatabaseEngine.TSQL + @property + def engine_specific_rules(self) -> str: + """Get the engine specific rules.""" + return """Use TOP X instead of LIMIT X to limit the number of rows returned.""" + @property def engine_specific_fields(self) -> list[str]: """Get the engine specific fields.""" From 6bd783742b367af9d85efe7c089ef75eaf689b2f Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Tue, 7 Jan 2025 18:46:02 +0000 Subject: [PATCH 2/3] Update engine specific rules logic --- .../src/autogen_text_2_sql/autogen_text_2_sql.py | 7 ++----- .../autogen_text_2_sql/creators/llm_agent_creator.py | 12 ++++++++++++ .../custom_agents/parallel_query_solving_agent.py | 7 ++----- .../autogen_text_2_sql/inner_autogen_text_2_sql.py | 6 +----- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py 
b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index ca63dce..82db5ee 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -31,9 +31,8 @@ class AutoGenText2Sql: - def __init__(self, engine_specific_rules: str, **kwargs: dict): + def __init__(self, **kwargs: dict): self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper() - self.engine_specific_rules = engine_specific_rules self.kwargs = kwargs def get_all_agents(self): @@ -45,9 +44,7 @@ def get_all_agents(self): "question_rewrite_agent", current_datetime=current_datetime ) - self.parallel_query_solving_agent = ParallelQuerySolvingAgent( - engine_specific_rules=self.engine_specific_rules, **self.kwargs - ) + self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs) self.answer_agent = LLMAgentCreator.create("answer_agent") diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py index d2a2039..5a26c7a 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py @@ -6,6 +6,7 @@ from text_2_sql_core.prompts.load import load from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator from jinja2 import Template +import logging class LLMAgentCreator: @@ -89,6 +90,17 @@ def create(cls, name: str, **kwargs) -> AssistantAgent: sql_helper = ConnectorFactory.get_database_connector() + # Handle engine specific rules + if "engine_specific_rules" not in kwargs: + if sql_helper.engine_specific_rules is not None: + kwargs["engine_specific_fields"] = sql_helper.engine_specific_rules + logging.info( + "Engine specific fields pulled from in-built: %s", + kwargs["engine_specific_fields"], + ) + else: + kwargs["engine_specific_fields"] = "" + tools = [] if "tools" in agent_file 
and len(agent_file["tools"]) > 0: for tool in agent_file["tools"]: diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index 53c1b86..05bfde0 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -20,13 +20,12 @@ class ParallelQuerySolvingAgent(BaseChatAgent): - def __init__(self, engine_specific_rules: str, **kwargs: dict): + def __init__(self, **kwargs: dict): super().__init__( "parallel_query_solving_agent", "An agent that solves each query in parallel.", ) - self.engine_specific_rules = engine_specific_rules self.kwargs = kwargs @property @@ -177,9 +176,7 @@ async def consume_inner_messages_from_agentic_flow( for question_rewrite in question_rewrites["sub_questions"]: logging.info(f"Processing sub-query: {question_rewrite}") # Create an instance of the InnerAutoGenText2Sql class - inner_autogen_text_2_sql = InnerAutoGenText2Sql( - self.engine_specific_rules, **self.kwargs - ) + inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs) identifier = ", ".join(question_rewrite) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index 34328a2..064c777 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -38,10 +38,9 @@ async def on_messages_stream(self, messages, sender=None, config=None): class InnerAutoGenText2Sql: - def __init__(self, engine_specific_rules: str, **kwargs: dict): + def __init__(self, **kwargs: dict): self.pre_run_query_cache = False self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper() - self.engine_specific_rules = engine_specific_rules 
self.kwargs = kwargs self.set_mode() @@ -73,21 +72,18 @@ def get_all_agents(self): self.sql_schema_selection_agent = SqlSchemaSelectionAgent( target_engine=self.target_engine, - engine_specific_rules=self.engine_specific_rules, **self.kwargs, ) self.sql_query_correction_agent = LLMAgentCreator.create( "sql_query_correction_agent", target_engine=self.target_engine, - engine_specific_rules=self.engine_specific_rules, **self.kwargs, ) self.disambiguation_and_sql_query_generation_agent = LLMAgentCreator.create( "disambiguation_and_sql_query_generation_agent", target_engine=self.target_engine, - engine_specific_rules=self.engine_specific_rules, **self.kwargs, ) agents = [ From 75d438f68692fe44f85049e874b32d083f7698d0 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Tue, 7 Jan 2025 18:46:37 +0000 Subject: [PATCH 3/3] Remove engine specific --- .../Iteration 5 - Agentic Vector Based Text2SQL.ipynb | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb index 23b2a91..05ce703 100644 --- a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb +++ b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb @@ -84,7 +84,7 @@ "metadata": {}, "outputs": [], "source": [ - "agentic_text_2_sql = AutoGenText2Sql(engine_specific_rules=\"\", use_case=\"Analysing sales data\")" + "agentic_text_2_sql = AutoGenText2Sql(use_case=\"Analysing sales data\")" ] }, { @@ -100,9 +100,16 @@ "metadata": {}, "outputs": [], "source": [ - "async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What total number of orders in June 2008?\")):\n", + "async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What is the total number of sales?\")):\n", " logging.info(\"Received %s Message from Text2SQL System\", message)" ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {