From eb398f835aedf3c16f36cc000492782a748da7e8 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 6 Jan 2025 11:17:21 +0000 Subject: [PATCH 1/3] Update prompts and agents --- .../parallel_query_solving_agent.py | 31 ++++--- .../custom_agents/sql_query_cache_agent.py | 34 ++------ .../sql_schema_selection_agent.py | 84 +++---------------- .../src/text_2_sql_core/connectors/open_ai.py | 37 ++++++-- .../text_2_sql_core/connectors/tsql_sql.py | 1 + .../text_2_sql_core/custom_agents/__init__.py | 0 .../custom_agents/sql_query_cache_agent.py | 39 +++++++++ .../sql_schema_selection_agent.py | 81 ++++++++++++++++++ ...uation_and_sql_query_generation_agent.yaml | 1 + .../src/text_2_sql_core/prompts/load.py | 2 + .../structured_outputs/__init__.py | 0 .../sql_schema_selection_agent.py | 8 ++ 12 files changed, 202 insertions(+), 116 deletions(-) create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/__init__.py create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/__init__.py create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/sql_schema_selection_agent.py diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index e05cbb00..18292a03 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -3,14 +3,20 @@ from typing import AsyncGenerator, List, Sequence from autogen_agentchat.agents import BaseChatAgent -from autogen_agentchat.base import Response, TaskResult -from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage +from autogen_agentchat.base import Response +from autogen_agentchat.messages import ( + AgentMessage, + ChatMessage, + TextMessage, + ToolCallResultMessage, +) from autogen_core import CancellationToken import json import logging from autogen_text_2_sql.inner_autogen_text_2_sql import InnerAutoGenText2Sql from aiostream import stream from json import JSONDecodeError +import re class ParallelQuerySolvingAgent(BaseChatAgent): @@ -53,9 +59,6 @@ def parse_inner_message(self, message): except JSONDecodeError: pass - # Try to extract JSON from markdown code blocks - import re - json_match = re.search(r"```json\s*(.*?)\s*```", message, re.DOTALL) if json_match: try: @@ -103,12 +106,13 @@ async def consume_inner_messages_from_agentic_flow( logging.info(f"Checking Inner Message: {inner_message}") - if isinstance(inner_message, TaskResult) is False: - try: + try: + if isinstance(inner_message, ToolCallResultMessage): + # Check for SQL query results parsed_message = self.parse_inner_message(inner_message.content) + logging.info(f"Inner Loaded: {parsed_message}") - # Search for specific message types and add them to the final output object if isinstance(parsed_message, dict): if ( "type" in parsed_message @@ -124,6 +128,13 @@ async def consume_inner_messages_from_agentic_flow( } ) + elif isinstance(inner_message, TextMessage): + parsed_message = self.parse_inner_message(inner_message.content) + + logging.info(f"Inner Loaded: {parsed_message}") + + # Search for specific message types and add them to the final output object + if isinstance(parsed_message, dict): if ("contains_pre_run_results" in parsed_message) and ( parsed_message["contains_pre_run_results"] is True ): @@ -139,8 +150,8 @@ async def consume_inner_messages_from_agentic_flow( } ) - except Exception as e: - logging.warning(f"Error processing message: {e}") + except Exception as e: + logging.warning(f"Error processing message: {e}") yield inner_message diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index a7ec5fb4..12994633 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -6,7 +6,9 @@ from autogen_agentchat.base import Response from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage from autogen_core import CancellationToken -from text_2_sql_core.connectors.factory import ConnectorFactory +from text_2_sql_core.custom_agents.sql_query_cache_agent import ( + SqlQueryCacheAgentCustomAgent, +) import json import logging @@ -18,7 +20,7 @@ def __init__(self): "An agent that fetches the queries from the cache based on the user question.", ) - self.sql_connector = ConnectorFactory.get_database_connector() + self.agent = SqlQueryCacheAgentCustomAgent() @property def produced_message_types(self) -> List[type[ChatMessage]]: @@ -49,31 +51,9 @@ async def on_messages_stream( # If not JSON array, process as single question raise ValueError("Could not load message") - # Initialize results dictionary - cached_results = { - "cached_questions_and_schemas": [], - "contains_pre_run_results": False, - } - - # Process each question sequentially - for question in user_questions: - # Fetch the queries from the cache based on the question - logging.info(f"Fetching queries from cache for question: {question}") - cached_query = await self.sql_connector.fetch_queries_from_cache( - question, injected_parameters=injected_parameters - ) - - # If any question has pre-run results, set the flag - if cached_query.get("contains_pre_run_results", False): - cached_results["contains_pre_run_results"] = True - - # Add the cached results for this question - if cached_query.get("cached_questions_and_schemas"): - cached_results["cached_questions_and_schemas"].extend( - cached_query["cached_questions_and_schemas"] - ) - - logging.info(f"Final cached results: {cached_results}") + cached_results = await self.agent.process_message( + user_questions, injected_parameters + ) yield Response( chat_message=TextMessage( content=json.dumps(cached_results), source=self.name diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py index b6345388..cd6fe44b 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py @@ -6,12 +6,11 @@ from autogen_agentchat.base import Response from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage from autogen_core import CancellationToken -from text_2_sql_core.connectors.factory import ConnectorFactory import json import logging -from text_2_sql_core.prompts.load import load -from jinja2 import Template -import asyncio +from text_2_sql_core.custom_agents.sql_schema_selection_agent import ( + SqlSchemaSelectionAgentCustomAgent, +) class SqlSchemaSelectionAgent(BaseChatAgent): @@ -21,15 +20,7 @@ def __init__(self, **kwargs): "An agent that fetches the schemas from the cache based on the user question.", ) - self.ai_search_connector = ConnectorFactory.get_ai_search_connector() - - self.open_ai_connector = ConnectorFactory.get_open_ai_connector() - - self.sql_connector = ConnectorFactory.get_database_connector() - - system_prompt = load("sql_schema_selection_agent")["system_message"] - - self.system_prompt = Template(system_prompt).render(kwargs) + self.agent = SqlSchemaSelectionAgentCustomAgent(**kwargs) @property def produced_message_types(self) -> List[type[ChatMessage]]: @@ -49,64 +40,15 @@ async def on_messages( async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: - last_response = messages[-1].content - - # load the json of the last message and get the user question's - - user_questions = json.loads(last_response) - - logging.info(f"User questions: {user_questions}") - - entity_tasks = [] - - for user_question in user_questions: - messages = [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": user_question}, - ] - entity_tasks.append(self.open_ai_connector.run_completion_request(messages)) - - entity_results = await asyncio.gather(*entity_tasks) - - entity_search_tasks = [] - column_search_tasks = [] - - for entity_result in entity_results: - loaded_entity_result = json.loads(entity_result) - - logging.info(f"Loaded entity result: {loaded_entity_result}") - - for entity_group in loaded_entity_result["entities"]: - entity_search_tasks.append( - self.sql_connector.get_entity_schemas( - " ".join(entity_group), as_json=False - ) - ) - - for filter_condition in loaded_entity_result["filter_conditions"]: - column_search_tasks.append( - self.ai_search_connector.get_column_values( - filter_condition, as_json=False - ) - ) - - schemas_results = await asyncio.gather(*entity_search_tasks) - column_value_results = await asyncio.gather(*column_search_tasks) - - # deduplicate schemas - final_schemas = [] - - for schema_result in schemas_results: - for schema in schema_result: - if schema not in final_schemas: - final_schemas.append(schema) - - final_results = { - "COLUMN_OPTIONS_AND_VALUES_FOR_FILTERS": column_value_results, - "SCHEMA_OPTIONS": final_schemas, - } - - logging.info(f"Final results: {final_results}") + try: + request_details = json.loads(messages[0].content) + user_questions = request_details["question"] + logging.info(f"Processing questions: {user_questions}") + except json.JSONDecodeError: + # If not JSON array, process as single question + raise ValueError("Could not load message") + + final_results = await self.agent.process_message(user_questions) yield Response( chat_message=TextMessage( diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py index f57361f2..707daae6 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py @@ -28,7 +28,12 @@ def get_authentication_properties(cls) -> dict: return token_provider, api_key async def run_completion_request( - self, messages: list[dict], temperature=0, max_tokens=2000, model="4o-mini" + self, + messages: list[dict], + temperature=0, + max_tokens=2000, + model="4o-mini", + response_format=None, ) -> str: if model == "4o-mini": model_deployment = os.environ["OpenAI__MiniCompletionDeployment"] @@ -45,13 +50,29 @@ async def run_completion_request( azure_ad_token_provider=token_provider, api_key=api_key, ) as open_ai_client: - response = await open_ai_client.chat.completions.create( - model=model_deployment, - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - ) - return response.choices[0].message.content + if response_format is not None: + response = await open_ai_client.beta.chat.completions.parse( + model=model_deployment, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + response_format=response_format, + ) + else: + response = await open_ai_client.chat.completions.create( + model=model_deployment, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + ) + + message = response.choices[0].message + if response_format is not None and message.parsed is not None: + return message.parsed + elif response_format is not None: + return message.refusal + else: + return message.content async def run_embedding_request(self, batch: list[str]): token_provider, api_key = self.get_authentication_properties() diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py index 6fb75011..875563dc 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py @@ -127,6 +127,7 @@ async def get_entity_schemas( del schema["Entity"] del schema["Schema"] + del schema["Database"] if as_json: return json.dumps(schemas, default=str) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/__init__.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py new file mode 100644 index 00000000..958d1a45 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from text_2_sql_core.connectors.factory import ConnectorFactory +import logging + + +class SqlQueryCacheAgentCustomAgent: + def __init__(self): + self.sql_connector = ConnectorFactory.get_database_connector() + + async def process_message( + self, user_questions: list[str], injected_parameters: dict + ) -> dict: + # Initialize results dictionary + cached_results = { + "cached_questions_and_schemas": [], + "contains_pre_run_results": False, + } + + # Process each question sequentially + for question in user_questions: + # Fetch the queries from the cache based on the question + logging.info(f"Fetching queries from cache for question: {question}") + cached_query = await self.sql_connector.fetch_queries_from_cache( + question, injected_parameters=injected_parameters + ) + + # If any question has pre-run results, set the flag + if cached_query.get("contains_pre_run_results", False): + cached_results["contains_pre_run_results"] = True + + # Add the cached results for this question + if cached_query.get("cached_questions_and_schemas"): + cached_results["cached_questions_and_schemas"].extend( + cached_query["cached_questions_and_schemas"] + ) + + logging.info(f"Final cached results: {cached_results}") + return cached_results diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py new file mode 100644 index 00000000..20ac7d4b --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from text_2_sql_core.connectors.factory import ConnectorFactory +import logging +from text_2_sql_core.prompts.load import load +from jinja2 import Template +import asyncio +from text_2_sql_core.structured_outputs.sql_schema_selection_agent import ( + SQLSchemaSelectionAgentOutput, +) + + +class SqlSchemaSelectionAgentCustomAgent: + def __init__(self, **kwargs): + self.ai_search_connector = ConnectorFactory.get_ai_search_connector() + + self.open_ai_connector = ConnectorFactory.get_open_ai_connector() + + self.sql_connector = ConnectorFactory.get_database_connector() + + system_prompt = load("sql_schema_selection_agent")["system_message"] + + self.system_prompt = Template(system_prompt).render(kwargs) + + async def process_message(self, user_questions: list[str]) -> dict: + logging.info(f"User questions: {user_questions}") + + entity_tasks = [] + + for user_question in user_questions: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": user_question}, + ] + entity_tasks.append( + self.open_ai_connector.run_completion_request( + messages, response_format=SQLSchemaSelectionAgentOutput + ) + ) + + entity_results = await asyncio.gather(*entity_tasks) + + entity_search_tasks = [] + column_search_tasks = [] + + for entity_result in entity_results: + logging.info(f"Entity result: {entity_result}") + + for entity_group in entity_result.entities: + entity_search_tasks.append( + self.sql_connector.get_entity_schemas( + " ".join(entity_group), as_json=False + ) + ) + + for filter_condition in entity_result.filter_conditions: + column_search_tasks.append( + self.ai_search_connector.get_column_values( + filter_condition, as_json=False + ) + ) + + schemas_results = await asyncio.gather(*entity_search_tasks) + column_value_results = await asyncio.gather(*column_search_tasks) + + # deduplicate schemas + final_schemas = [] + + for schema_result in schemas_results: + for schema in schema_result: + if schema not in final_schemas: + final_schemas.append(schema) + + final_results = { + "COLUMN_OPTIONS_AND_VALUES_FOR_FILTERS": column_value_results, + "SCHEMA_OPTIONS": final_schemas, + } + + logging.info(f"Final results: {final_results}") + + return final_results diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml index c8fe5b84..32ab8971 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml @@ -154,6 +154,7 @@ system_message: * Detailed syntax corrections * Query execution * Result formatting + - For a given entity, use the 'SelectFromEntity' property in the SELECT FROM part of the SQL query. If the property is {'SelectFromEntity': 'test_schema.test_table'}, the select statement will be formulated from 'SELECT FROM test_schema.test_table WHERE . Remember: Your job is to focus on the data relationships and logic while following basic {{ target_engine }} patterns. diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/load.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/load.py index 313954fa..482e140f 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/load.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/load.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import yaml import os diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/__init__.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/sql_schema_selection_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/sql_schema_selection_agent.py new file mode 100644 index 00000000..dd5e6c67 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/sql_schema_selection_agent.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from pydantic import BaseModel + + +class SQLSchemaSelectionAgentOutput(BaseModel): + entities: list[list[str]] + filter_conditions: list[str] From 4e00504dccebb50c9ce084c70003ce0b75ed0d1c Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 6 Jan 2025 12:40:23 +0000 Subject: [PATCH 2/3] Map engine specific deals --- .../src/text_2_sql_core/connectors/ai_search.py | 6 ++++-- .../src/text_2_sql_core/connectors/databricks_sql.py | 2 +- .../src/text_2_sql_core/connectors/postgresql_sql.py | 3 ++- .../src/text_2_sql_core/connectors/snowflake_sql.py | 2 +- .../src/text_2_sql_core/connectors/tsql_sql.py | 2 +- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py index b3c03b7b..22d43de6 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py @@ -13,6 +13,8 @@ from typing import Annotated from text_2_sql_core.connectors.open_ai import OpenAIConnector +from text_2_sql_core.utils.database import DatabaseEngineSpecificFields + class AISearchConnector: def __init__(self): @@ -167,7 +169,7 @@ async def get_entity_schemas( "The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.", ] = [], engine_specific_fields: Annotated[ - list[str], + list[DatabaseEngineSpecificFields], "The fields specific to the engine to be included in the search results.", ] = [], ) -> str: @@ -192,7 +194,7 @@ async def get_entity_schemas( "Columns", "EntityRelationships", "CompleteEntityRelationshipsGraph", - ] + engine_specific_fields + ] + list(map(str, engine_specific_fields)) schemas = await self.run_ai_search_query( text, diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py index 6271c12d..c72322c8 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py @@ -135,7 +135,7 @@ async def get_entity_schemas( """ schemas = await self.ai_search_connector.get_entity_schemas( - text, excluded_entities, engine_specific_fields=["Catalog"] + text, excluded_entities, engine_specific_fields=self.engine_specific_fields ) for schema in schemas: diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py index 522554b0..f1060417 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py @@ -111,7 +111,7 @@ async def get_entity_schemas( """ schemas = await self.ai_search_connector.get_entity_schemas( - text, excluded_entities + text, excluded_entities, engine_specific_fields=self.engine_specific_fields ) for schema in schemas: @@ -119,6 +119,7 @@ async def get_entity_schemas( del schema["Entity"] del schema["Schema"] + del schema["Database"] if as_json: return json.dumps(schemas, default=str) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py index eb8d4b84..e71cd013 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py @@ -150,7 +150,7 @@ async def get_entity_schemas( """ schemas = await self.ai_search_connector.get_entity_schemas( - text, excluded_entities, engine_specific_fields=["Warehouse", "Database"] + text, excluded_entities, engine_specific_fields=self.engine_specific_fields ) for schema in schemas: diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py index 425750f9..36c1f83c 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py @@ -124,7 +124,7 @@ async def get_entity_schemas( """ schemas = await self.ai_search_connector.get_entity_schemas( - text, excluded_entities + text, excluded_entities, engine_specific_fields=self.engine_specific_fields ) for schema in schemas: From e1a1e03369acf816ea17843d3d713427c67b7370 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 6 Jan 2025 13:42:19 +0000 Subject: [PATCH 3/3] Fix agents --- .../creators/llm_agent_creator.py | 7 ++-- .../parallel_query_solving_agent.py | 42 ++++++++++--------- .../text_2_sql_core/connectors/ai_search.py | 41 +++++++----------- .../src/text_2_sql_core/connectors/sql.py | 42 +++++++++++++++++++ .../sql_schema_selection_agent.py | 2 +- 5 files changed, 84 insertions(+), 50 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py index 6089d89d..d2a2039a 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py @@ -24,7 +24,7 @@ def load_agent_file(cls, name: str) -> dict: return load(name.lower()) @classmethod - def get_tool(cls, sql_helper, ai_search_helper, tool_name: str): + def get_tool(cls, sql_helper, tool_name: str): """Gets the tool based on the tool name. Args: ---- @@ -46,7 +46,7 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str): ) elif tool_name == "sql_get_column_values_tool": return FunctionToolAlias( - ai_search_helper.get_column_values, + sql_helper.get_column_values, description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.", ) else: @@ -88,12 +88,11 @@ def create(cls, name: str, **kwargs) -> AssistantAgent: agent_file = cls.load_agent_file(name) sql_helper = ConnectorFactory.get_database_connector() - ai_search_helper = ConnectorFactory.get_ai_search_connector() tools = [] if "tools" in agent_file and len(agent_file["tools"]) > 0: for tool in agent_file["tools"]: - tools.append(cls.get_tool(sql_helper, ai_search_helper, tool)) + tools.append(cls.get_tool(sql_helper, tool)) agent = AssistantAgent( name=name, diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index 18292a03..6abf164e 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -108,25 +108,28 @@ async def consume_inner_messages_from_agentic_flow( try: if isinstance(inner_message, ToolCallResultMessage): - # Check for SQL query results - parsed_message = self.parse_inner_message(inner_message.content) - - logging.info(f"Inner Loaded: {parsed_message}") - - if isinstance(parsed_message, dict): - if ( - "type" in parsed_message - and parsed_message["type"] - == "query_execution_with_limit" - ): - database_results[identifier].append( - { - "sql_query": parsed_message[ - "sql_query" - ].replace("\n", " "), - "sql_rows": parsed_message["sql_rows"], - } - ) + for call_result in inner_message.content: + # Check for SQL query results + parsed_message = self.parse_inner_message( + call_result.content + ) + logging.info(f"Inner Loaded: {parsed_message}") + + if isinstance(parsed_message, dict): + if ( + "type" in parsed_message + and parsed_message["type"] + == "query_execution_with_limit" + ): + logging.info("Contains query results") + database_results[identifier].append( + { + "sql_query": parsed_message[ + "sql_query" + ].replace("\n", " "), + "sql_rows": parsed_message["sql_rows"], + } + ) elif isinstance(inner_message, TextMessage): parsed_message = self.parse_inner_message(inner_message.content) @@ -138,6 +141,7 @@ async def consume_inner_messages_from_agentic_flow( if ("contains_pre_run_results" in parsed_message) and ( parsed_message["contains_pre_run_results"] is True ): + logging.info("Contains pre-run results") for pre_run_sql_query, pre_run_result in parsed_message[ "cached_questions_and_schemas" ].items(): diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py index 22d43de6..7eb26f46 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py @@ -9,7 +9,6 @@ import logging import base64 from datetime import datetime, timezone -import json from typing import Annotated from text_2_sql_core.connectors.open_ai import OpenAIConnector @@ -111,7 +110,6 @@ async def get_column_values( str, "The text to run a semantic search against. Relevant entities will be returned.", ], - as_json: bool = True, ): """Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. @@ -139,24 +137,7 @@ async def get_column_values( minimum_score=5, ) - # build into a common format - column_values = {} - - for value in values: - trimmed_fqn = ".".join(value["FQN"].split(".")[:-1]) - if trimmed_fqn not in column_values: - column_values[trimmed_fqn] = [] - - column_values[trimmed_fqn].append(value["Value"]) - - logging.info("Column Values: %s", column_values) - - filter_to_column = {text: column_values} - - if as_json: - return json.dumps(filter_to_column, default=str) - else: - return filter_to_column + return values async def get_entity_schemas( self, @@ -185,6 +166,8 @@ async def get_entity_schemas( logging.info("Search Text: %s", text) + stringified_engine_specific_fields = list(map(str, engine_specific_fields)) + retrieval_fields = [ "FQN", "Entity", @@ -194,7 +177,7 @@ async def get_entity_schemas( "Columns", "EntityRelationships", "CompleteEntityRelationshipsGraph", - ] + list(map(str, engine_specific_fields)) + ] + stringified_engine_specific_fields schemas = await self.run_ai_search_query( text, @@ -207,6 +190,8 @@ async def get_entity_schemas( top=3, ) + fqn_to_trim = ".".join(stringified_engine_specific_fields) + if len(excluded_entities) == 0: return schemas @@ -220,12 +205,16 @@ async def get_entity_schemas( and len(schema["CompleteEntityRelationshipsGraph"]) == 0 ): del schema["CompleteEntityRelationshipsGraph"] + else: + schema["CompleteEntityRelationshipsGraph"] = list( + map( + lambda x: x.replace(fqn_to_trim, ""), + schema["CompleteEntityRelationshipsGraph"], + ) + ) - if ( - schema["SammpleValues"] is not None - and len(schema["SammpleValues"]) == 0 - ): - del schema["SammpleValues"] + if schema["SampleValues"] is not None and len(schema["SampleValues"]) == 0: + del schema["SampleValues"] if ( schema["EntityRelationships"] is not None diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 21a93780..87cd30a0 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -74,6 +74,48 @@ async def query_execution( list[dict]: The results of the SQL query. """ + async def get_column_values( + self, + text: Annotated[ + str, + "The text to run a semantic search against. Relevant entities will be returned.", + ], + as_json: bool = True, + ): + """Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. + + Args: + ---- + text (str): The text to run the search against. + + Returns: + ------- + str: The values of the column in JSON format. + """ + + values = await self.ai_search_connector.get_column_values(text) + + # build into a common format + column_values = {} + + starting = len(self.engine_specific_fields) + + for value in values: + trimmed_fqn = ".".join(value["FQN"].split(".")[starting:-1]) + if trimmed_fqn not in column_values: + column_values[trimmed_fqn] = [] + + column_values[trimmed_fqn].append(value["Value"]) + + logging.info("Column Values: %s", column_values) + + filter_to_column = {text: column_values} + + if as_json: + return json.dumps(filter_to_column, default=str) + else: + return filter_to_column + @abstractmethod async def get_entity_schemas( self, diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py index 20ac7d4b..b02c5e66 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py @@ -55,7 +55,7 @@ async def process_message(self, user_questions: list[str]) -> dict: for filter_condition in entity_result.filter_conditions: column_search_tasks.append( - self.ai_search_connector.get_column_values( + self.sql_connector.get_column_values( filter_condition, as_json=False ) )