diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py index 6089d89..d2a2039 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py @@ -24,7 +24,7 @@ def load_agent_file(cls, name: str) -> dict: return load(name.lower()) @classmethod - def get_tool(cls, sql_helper, ai_search_helper, tool_name: str): + def get_tool(cls, sql_helper, tool_name: str): """Gets the tool based on the tool name. Args: ---- @@ -46,7 +46,7 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str): ) elif tool_name == "sql_get_column_values_tool": return FunctionToolAlias( - ai_search_helper.get_column_values, + sql_helper.get_column_values, description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.", ) else: @@ -88,12 +88,11 @@ def create(cls, name: str, **kwargs) -> AssistantAgent: agent_file = cls.load_agent_file(name) sql_helper = ConnectorFactory.get_database_connector() - ai_search_helper = ConnectorFactory.get_ai_search_connector() tools = [] if "tools" in agent_file and len(agent_file["tools"]) > 0: for tool in agent_file["tools"]: - tools.append(cls.get_tool(sql_helper, ai_search_helper, tool)) + tools.append(cls.get_tool(sql_helper, tool)) agent = AssistantAgent( name=name, diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index e05cbb0..6abf164 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -3,14 +3,20 @@ from typing import AsyncGenerator, List, Sequence from autogen_agentchat.agents import BaseChatAgent -from autogen_agentchat.base import Response, TaskResult -from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage +from autogen_agentchat.base import Response +from autogen_agentchat.messages import ( + AgentMessage, + ChatMessage, + TextMessage, + ToolCallResultMessage, +) from autogen_core import CancellationToken import json import logging from autogen_text_2_sql.inner_autogen_text_2_sql import InnerAutoGenText2Sql from aiostream import stream from json import JSONDecodeError +import re class ParallelQuerySolvingAgent(BaseChatAgent): @@ -53,9 +59,6 @@ def parse_inner_message(self, message): except JSONDecodeError: pass - # Try to extract JSON from markdown code blocks - import re - json_match = re.search(r"```json\s*(.*?)\s*```", message, re.DOTALL) if json_match: try: @@ -103,30 +106,42 @@ async def consume_inner_messages_from_agentic_flow( logging.info(f"Checking Inner Message: {inner_message}") - if isinstance(inner_message, TaskResult) is False: - try: + try: + if isinstance(inner_message, ToolCallResultMessage): + for call_result in inner_message.content: + # Check for SQL query results + parsed_message = self.parse_inner_message( + call_result.content + ) + logging.info(f"Inner Loaded: {parsed_message}") + + if isinstance(parsed_message, dict): + if ( + "type" in parsed_message + and parsed_message["type"] + == "query_execution_with_limit" + ): + logging.info("Contains query results") + database_results[identifier].append( + { + "sql_query": parsed_message[ + "sql_query" + ].replace("\n", " "), + "sql_rows": parsed_message["sql_rows"], + } + ) + + elif isinstance(inner_message, TextMessage): parsed_message = self.parse_inner_message(inner_message.content) + logging.info(f"Inner Loaded: {parsed_message}") # Search for specific message types and add them to the final output object if isinstance(parsed_message, dict): - if ( - "type" in parsed_message - and parsed_message["type"] - == "query_execution_with_limit" - ): - database_results[identifier].append( - { - "sql_query": parsed_message[ - "sql_query" - ].replace("\n", " "), - "sql_rows": parsed_message["sql_rows"], - } - ) - if ("contains_pre_run_results" in parsed_message) and ( parsed_message["contains_pre_run_results"] is True ): + logging.info("Contains pre-run results") for pre_run_sql_query, pre_run_result in parsed_message[ "cached_questions_and_schemas" ].items(): @@ -139,8 +154,8 @@ async def consume_inner_messages_from_agentic_flow( } ) - except Exception as e: - logging.warning(f"Error processing message: {e}") + except Exception as e: + logging.warning(f"Error processing message: {e}") yield inner_message diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index a7ec5fb..1299463 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -6,7 +6,9 @@ from autogen_agentchat.base import Response from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage from autogen_core import CancellationToken -from text_2_sql_core.connectors.factory import ConnectorFactory +from text_2_sql_core.custom_agents.sql_query_cache_agent import ( + SqlQueryCacheAgentCustomAgent, +) import json import logging @@ -18,7 +20,7 @@ def __init__(self): "An agent that fetches the queries from the cache based on the user question.", ) - self.sql_connector = ConnectorFactory.get_database_connector() + self.agent = SqlQueryCacheAgentCustomAgent() @property def produced_message_types(self) -> List[type[ChatMessage]]: @@ -49,31 +51,9 @@ async def on_messages_stream( # If not JSON array, process as single question raise ValueError("Could not load message") - # Initialize results dictionary - cached_results = { - "cached_questions_and_schemas": [], - "contains_pre_run_results": False, - } - - # Process each question sequentially - for question in user_questions: - # Fetch the queries from the cache based on the question - logging.info(f"Fetching queries from cache for question: {question}") - cached_query = await self.sql_connector.fetch_queries_from_cache( - question, injected_parameters=injected_parameters - ) - - # If any question has pre-run results, set the flag - if cached_query.get("contains_pre_run_results", False): - cached_results["contains_pre_run_results"] = True - - # Add the cached results for this question - if cached_query.get("cached_questions_and_schemas"): - cached_results["cached_questions_and_schemas"].extend( - cached_query["cached_questions_and_schemas"] - ) - - logging.info(f"Final cached results: {cached_results}") + cached_results = await self.agent.process_message( + user_questions, injected_parameters + ) yield Response( chat_message=TextMessage( content=json.dumps(cached_results), source=self.name diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py index b634538..cd6fe44 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py @@ -6,12 +6,11 @@ from autogen_agentchat.base import Response from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage from autogen_core import CancellationToken -from text_2_sql_core.connectors.factory import ConnectorFactory import json import logging -from text_2_sql_core.prompts.load import load -from jinja2 import Template -import asyncio +from text_2_sql_core.custom_agents.sql_schema_selection_agent import ( + SqlSchemaSelectionAgentCustomAgent, +) class SqlSchemaSelectionAgent(BaseChatAgent): @@ -21,15 +20,7 @@ def __init__(self, **kwargs): "An agent that fetches the schemas from the cache based on the user question.", ) - self.ai_search_connector = ConnectorFactory.get_ai_search_connector() - - self.open_ai_connector = ConnectorFactory.get_open_ai_connector() - - self.sql_connector = ConnectorFactory.get_database_connector() - - system_prompt = load("sql_schema_selection_agent")["system_message"] - - self.system_prompt = Template(system_prompt).render(kwargs) + self.agent = SqlSchemaSelectionAgentCustomAgent(**kwargs) @property def produced_message_types(self) -> List[type[ChatMessage]]: @@ -49,64 +40,15 @@ async def on_messages( async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: - last_response = messages[-1].content - - # load the json of the last message and get the user question's - - user_questions = json.loads(last_response) - - logging.info(f"User questions: {user_questions}") - - entity_tasks = [] - - for user_question in user_questions: - messages = [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": user_question}, - ] - entity_tasks.append(self.open_ai_connector.run_completion_request(messages)) - - entity_results = await asyncio.gather(*entity_tasks) - - entity_search_tasks = [] - column_search_tasks = [] - - for entity_result in entity_results: - loaded_entity_result = json.loads(entity_result) - - logging.info(f"Loaded entity result: {loaded_entity_result}") - - for entity_group in loaded_entity_result["entities"]: - entity_search_tasks.append( - self.sql_connector.get_entity_schemas( - " ".join(entity_group), as_json=False - ) - ) - - for filter_condition in loaded_entity_result["filter_conditions"]: - column_search_tasks.append( - self.ai_search_connector.get_column_values( - filter_condition, as_json=False - ) - ) - - schemas_results = await asyncio.gather(*entity_search_tasks) - column_value_results = await asyncio.gather(*column_search_tasks) - - # deduplicate schemas - final_schemas = [] - - for schema_result in schemas_results: - for schema in schema_result: - if schema not in final_schemas: - final_schemas.append(schema) - - final_results = { - "COLUMN_OPTIONS_AND_VALUES_FOR_FILTERS": column_value_results, - "SCHEMA_OPTIONS": final_schemas, - } - - logging.info(f"Final results: {final_results}") + try: + request_details = json.loads(messages[0].content) + user_questions = request_details["question"] + logging.info(f"Processing questions: {user_questions}") + except json.JSONDecodeError: + # If not JSON array, process as single question + raise ValueError("Could not load message") + + final_results = await self.agent.process_message(user_questions) yield Response( chat_message=TextMessage( diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py index b3c03b7..7eb26f4 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py @@ -9,10 +9,11 @@ import logging import base64 from datetime import datetime, timezone -import json from typing import Annotated from text_2_sql_core.connectors.open_ai import OpenAIConnector +from text_2_sql_core.utils.database import DatabaseEngineSpecificFields + class AISearchConnector: def __init__(self): @@ -109,7 +110,6 @@ async def get_column_values( str, "The text to run a semantic search against. Relevant entities will be returned.", ], - as_json: bool = True, ): """Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. @@ -137,24 +137,7 @@ async def get_column_values( minimum_score=5, ) - # build into a common format - column_values = {} - - for value in values: - trimmed_fqn = ".".join(value["FQN"].split(".")[:-1]) - if trimmed_fqn not in column_values: - column_values[trimmed_fqn] = [] - - column_values[trimmed_fqn].append(value["Value"]) - - logging.info("Column Values: %s", column_values) - - filter_to_column = {text: column_values} - - if as_json: - return json.dumps(filter_to_column, default=str) - else: - return filter_to_column + return values async def get_entity_schemas( self, @@ -167,7 +150,7 @@ async def get_entity_schemas( "The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.", ] = [], engine_specific_fields: Annotated[ - list[str], + list[DatabaseEngineSpecificFields], "The fields specific to the engine to be included in the search results.", ] = [], ) -> str: @@ -183,6 +166,8 @@ async def get_entity_schemas( logging.info("Search Text: %s", text) + stringified_engine_specific_fields = list(map(str, engine_specific_fields)) + retrieval_fields = [ "FQN", "Entity", @@ -192,7 +177,7 @@ async def get_entity_schemas( "Columns", "EntityRelationships", "CompleteEntityRelationshipsGraph", - ] + engine_specific_fields + ] + stringified_engine_specific_fields schemas = await self.run_ai_search_query( text, @@ -205,6 +190,8 @@ async def get_entity_schemas( top=3, ) + fqn_to_trim = ".".join(stringified_engine_specific_fields) + if len(excluded_entities) == 0: return schemas @@ -218,12 +205,16 @@ async def get_entity_schemas( and len(schema["CompleteEntityRelationshipsGraph"]) == 0 ): del schema["CompleteEntityRelationshipsGraph"] + else: + schema["CompleteEntityRelationshipsGraph"] = list( + map( + lambda x: x.replace(fqn_to_trim, ""), + schema["CompleteEntityRelationshipsGraph"], + ) + ) - if ( - schema["SammpleValues"] is not None - and len(schema["SammpleValues"]) == 0 - ): - del schema["SammpleValues"] + if schema["SampleValues"] is not None and len(schema["SampleValues"]) == 0: + del schema["SampleValues"] if ( schema["EntityRelationships"] is not None diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py index 6271c12..c72322c 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py @@ -135,7 +135,7 @@ async def get_entity_schemas( """ schemas = await self.ai_search_connector.get_entity_schemas( - text, excluded_entities, engine_specific_fields=["Catalog"] + text, excluded_entities, engine_specific_fields=self.engine_specific_fields ) for schema in schemas: diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py index f57361f..707daae 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py @@ -28,7 +28,12 @@ def get_authentication_properties(cls) -> dict: return token_provider, api_key async def run_completion_request( - self, messages: list[dict], temperature=0, max_tokens=2000, model="4o-mini" + self, + messages: list[dict], + temperature=0, + max_tokens=2000, + model="4o-mini", + response_format=None, ) -> str: if model == "4o-mini": model_deployment = os.environ["OpenAI__MiniCompletionDeployment"] @@ -45,13 +50,29 @@ async def run_completion_request( azure_ad_token_provider=token_provider, api_key=api_key, ) as open_ai_client: - response = await open_ai_client.chat.completions.create( - model=model_deployment, - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - ) - return response.choices[0].message.content + if response_format is not None: + response = await open_ai_client.beta.chat.completions.parse( + model=model_deployment, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + response_format=response_format, + ) + else: + response = await open_ai_client.chat.completions.create( + model=model_deployment, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + ) + + message = response.choices[0].message + if response_format is not None and message.parsed is not None: + return message.parsed + elif response_format is not None: + return message.refusal + else: + return message.content async def run_embedding_request(self, batch: list[str]): token_provider, api_key = self.get_authentication_properties() diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py index 522554b..f106041 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py @@ -111,7 +111,7 @@ async def get_entity_schemas( """ schemas = await self.ai_search_connector.get_entity_schemas( - text, excluded_entities + text, excluded_entities, engine_specific_fields=self.engine_specific_fields ) for schema in schemas: @@ -119,6 +119,7 @@ async def get_entity_schemas( del schema["Entity"] del schema["Schema"] + del schema["Database"] if as_json: return json.dumps(schemas, default=str) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py index eb8d4b8..e71cd01 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py @@ -150,7 +150,7 @@ async def get_entity_schemas( """ schemas = await self.ai_search_connector.get_entity_schemas( - text, excluded_entities, engine_specific_fields=["Warehouse", "Database"] + text, excluded_entities, engine_specific_fields=self.engine_specific_fields ) for schema in schemas: diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 21a9378..87cd30a 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -74,6 +74,48 @@ async def query_execution( list[dict]: The results of the SQL query. """ + async def get_column_values( + self, + text: Annotated[ + str, + "The text to run a semantic search against. Relevant entities will be returned.", + ], + as_json: bool = True, + ): + """Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. + + Args: + ---- + text (str): The text to run the search against. + + Returns: + ------- + str: The values of the column in JSON format. + """ + + values = await self.ai_search_connector.get_column_values(text) + + # build into a common format + column_values = {} + + starting = len(self.engine_specific_fields) + + for value in values: + trimmed_fqn = ".".join(value["FQN"].split(".")[starting:-1]) + if trimmed_fqn not in column_values: + column_values[trimmed_fqn] = [] + + column_values[trimmed_fqn].append(value["Value"]) + + logging.info("Column Values: %s", column_values) + + filter_to_column = {text: column_values} + + if as_json: + return json.dumps(filter_to_column, default=str) + else: + return filter_to_column + @abstractmethod async def get_entity_schemas( self, diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py index 58e8e53..36c1f83 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py @@ -124,7 +124,7 @@ async def get_entity_schemas( """ schemas = await self.ai_search_connector.get_entity_schemas( - text, excluded_entities + text, excluded_entities, engine_specific_fields=self.engine_specific_fields ) for schema in schemas: @@ -132,6 +132,7 @@ async def get_entity_schemas( del schema["Entity"] del schema["Schema"] + del schema["Database"] if as_json: return json.dumps(schemas, default=str) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/__init__.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py new file mode 100644 index 0000000..958d1a4 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from text_2_sql_core.connectors.factory import ConnectorFactory +import logging + + +class SqlQueryCacheAgentCustomAgent: + def __init__(self): + self.sql_connector = ConnectorFactory.get_database_connector() + + async def process_message( + self, user_questions: list[str], injected_parameters: dict + ) -> dict: + # Initialize results dictionary + cached_results = { + "cached_questions_and_schemas": [], + "contains_pre_run_results": False, + } + + # Process each question sequentially + for question in user_questions: + # Fetch the queries from the cache based on the question + logging.info(f"Fetching queries from cache for question: {question}") + cached_query = await self.sql_connector.fetch_queries_from_cache( + question, injected_parameters=injected_parameters + ) + + # If any question has pre-run results, set the flag + if cached_query.get("contains_pre_run_results", False): + cached_results["contains_pre_run_results"] = True + + # Add the cached results for this question + if cached_query.get("cached_questions_and_schemas"): + cached_results["cached_questions_and_schemas"].extend( + cached_query["cached_questions_and_schemas"] + ) + + logging.info(f"Final cached results: {cached_results}") + return cached_results diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py new file mode 100644 index 0000000..b02c5e6 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from text_2_sql_core.connectors.factory import ConnectorFactory +import logging +from text_2_sql_core.prompts.load import load +from jinja2 import Template +import asyncio +from text_2_sql_core.structured_outputs.sql_schema_selection_agent import ( + SQLSchemaSelectionAgentOutput, +) + + +class SqlSchemaSelectionAgentCustomAgent: + def __init__(self, **kwargs): + self.ai_search_connector = ConnectorFactory.get_ai_search_connector() + + self.open_ai_connector = ConnectorFactory.get_open_ai_connector() + + self.sql_connector = ConnectorFactory.get_database_connector() + + system_prompt = load("sql_schema_selection_agent")["system_message"] + + self.system_prompt = Template(system_prompt).render(kwargs) + + async def process_message(self, user_questions: list[str]) -> dict: + logging.info(f"User questions: {user_questions}") + + entity_tasks = [] + + for user_question in user_questions: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": user_question}, + ] + entity_tasks.append( + self.open_ai_connector.run_completion_request( + messages, response_format=SQLSchemaSelectionAgentOutput + ) + ) + + entity_results = await asyncio.gather(*entity_tasks) + + entity_search_tasks = [] + column_search_tasks = [] + + for entity_result in entity_results: + logging.info(f"Entity result: {entity_result}") + + for entity_group in entity_result.entities: + entity_search_tasks.append( + self.sql_connector.get_entity_schemas( + " ".join(entity_group), as_json=False + ) + ) + + for filter_condition in entity_result.filter_conditions: + column_search_tasks.append( + self.sql_connector.get_column_values( + filter_condition, as_json=False + ) + ) + + schemas_results = await asyncio.gather(*entity_search_tasks) + column_value_results = await asyncio.gather(*column_search_tasks) + + # deduplicate schemas + final_schemas = [] + + for schema_result in schemas_results: + for schema in schema_result: + if schema not in final_schemas: + final_schemas.append(schema) + + final_results = { + "COLUMN_OPTIONS_AND_VALUES_FOR_FILTERS": column_value_results, + "SCHEMA_OPTIONS": final_schemas, + } + + logging.info(f"Final results: {final_results}") + + return final_results diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml index c8fe5b8..32ab897 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml @@ -154,6 +154,7 @@ system_message: * Detailed syntax corrections * Query execution * Result formatting + - For a given entity, use the 'SelectFromEntity' property in the SELECT FROM part of the SQL query. If the property is {'SelectFromEntity': 'test_schema.test_table'}, the select statement will be formulated from 'SELECT FROM test_schema.test_table WHERE . Remember: Your job is to focus on the data relationships and logic while following basic {{ target_engine }} patterns. diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/load.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/load.py index 313954f..482e140 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/load.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/load.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import yaml import os diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/__init__.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/sql_schema_selection_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/sql_schema_selection_agent.py new file mode 100644 index 0000000..dd5e6c6 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/sql_schema_selection_agent.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from pydantic import BaseModel + + +class SQLSchemaSelectionAgentOutput(BaseModel): + entities: list[list[str]] + filter_conditions: list[str]