diff --git a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb index 05ce703..edcf12b 100644 --- a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb +++ b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb @@ -50,7 +50,7 @@ "source": [ "import dotenv\n", "import logging\n", - "from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload" + "from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload" ] }, { @@ -100,7 +100,7 @@ "metadata": {}, "outputs": [], "source": [ - "async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What is the total number of sales?\")):\n", + "async for message in agentic_text_2_sql.process_user_message(UserMessagePayload(user_message=\"What is the total number of sales?\")):\n", " logging.info(\"Received %s Message from Text2SQL System\", message)" ] }, diff --git a/text_2_sql/autogen/evaluate_autogen_text2sql.ipynb b/text_2_sql/autogen/evaluate_autogen_text2sql.ipynb index 8b90716..2b38a09 100644 --- a/text_2_sql/autogen/evaluate_autogen_text2sql.ipynb +++ b/text_2_sql/autogen/evaluate_autogen_text2sql.ipynb @@ -68,7 +68,7 @@ "# Add the src directory to the path\n", "sys.path.append(str(notebook_dir / \"src\"))\n", "\n", - "from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload\n", + "from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload\n", "from autogen_text_2_sql.evaluation_utils import get_final_sql_query\n", "\n", "# Configure logging\n", @@ -127,7 +127,7 @@ " all_queries = []\n", " final_query = None\n", " \n", - " async for message in autogen_text2sql.process_question(QuestionPayload(question=question)):\n", + " async for message in autogen_text2sql.process_user_message(UserMessagePayload(user_message=question)):\n", " if message.payload_type == \"answer_with_sources\":\n", " # Extract from results\n", " if hasattr(message.body, 'results'):\n", diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py index cda320f..bf72f7d 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql -from text_2_sql_core.payloads.interaction_payloads import QuestionPayload +from text_2_sql_core.payloads.interaction_payloads import UserMessagePayload -__all__ = ["AutoGenText2Sql", "QuestionPayload"] +__all__ = ["AutoGenText2Sql", "UserMessagePayload"] diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index f65f54c..45ca468 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -19,9 +19,9 @@ import re from text_2_sql_core.payloads.interaction_payloads import ( - QuestionPayload, + UserMessagePayload, AnswerWithSourcesPayload, - DismabiguationRequestPayload, + DismabiguationRequestsPayload, ProcessingUpdatePayload, InteractionPayload, PayloadType, @@ -40,8 +40,8 @@ def get_all_agents(self): # Get current datetime for the Query Rewrite Agent current_datetime = datetime.now() - self.question_rewrite_agent = LLMAgentCreator.create( - "question_rewrite_agent", current_datetime=current_datetime + self.user_message_rewrite_agent = LLMAgentCreator.create( + "user_message_rewrite_agent", current_datetime=current_datetime ) self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs) @@ -49,7 +49,7 @@ def get_all_agents(self): self.answer_agent = LLMAgentCreator.create("answer_agent") agents = [ - self.question_rewrite_agent, + self.user_message_rewrite_agent, self.parallel_query_solving_agent, self.answer_agent, ] @@ -62,7 +62,7 @@ def termination_condition(self): termination = ( TextMentionTermination("TERMINATE") | SourceMatchTermination("answer_agent") - | TextMentionTermination("requires_user_information_request") + | TextMentionTermination("contains_disambiguation_requests") | MaxMessageTermination(5) ) return termination @@ -73,11 +73,11 @@ def unified_selector(self, messages): current_agent = messages[-1].source if messages else "user" decision = None - # If this is the first message start with question_rewrite_agent + # If this is the first message start with user_message_rewrite_agent if current_agent == "user": - decision = "question_rewrite_agent" + decision = "user_message_rewrite_agent" # Handle transition after query rewriting - elif current_agent == "question_rewrite_agent": + elif current_agent == "user_message_rewrite_agent": decision = "parallel_query_solving_agent" # Handle transition after parallel query solving elif current_agent == "parallel_query_solving_agent": @@ -102,15 +102,6 @@ def agentic_flow(self): ) return flow - def extract_disambiguation_request( - self, messages: list - ) -> DismabiguationRequestPayload: - """Extract the disambiguation request from the answer.""" - disambiguation_request = messages[-1].content - return DismabiguationRequestPayload( - disambiguation_request=disambiguation_request, - ) - def parse_message_content(self, content): """Parse different message content formats into a dictionary.""" if isinstance(content, (list, dict)): @@ -134,6 +125,49 @@ def parse_message_content(self, content): # If all parsing attempts fail, return the content as-is return content + def extract_decomposed_user_messages(self, messages: list) -> list[list[str]]: + """Extract the decomposed messages from the answer.""" + # Only load sub-message results if we have a database result + sub_message_results = self.parse_message_content(messages[1].content) + logging.info("Decomposed Results: %s", sub_message_results) + + decomposed_user_messages = sub_message_results.get( + "decomposed_user_messages", [] + ) + + logging.debug( + "Returning decomposed_user_messages: %s", decomposed_user_messages + ) + + return decomposed_user_messages + + def extract_disambiguation_request( + self, messages: list + ) -> DismabiguationRequestsPayload: + """Extract the disambiguation request from the answer.""" + all_disambiguation_requests = self.parse_message_content(messages[-1].content) + + decomposed_user_messages = self.extract_decomposed_user_messages(messages) + request_payload = DismabiguationRequestsPayload( + decomposed_user_messages=decomposed_user_messages + ) + + for per_question_disambiguation_request in all_disambiguation_requests[ + "disambiguation_requests" + ].values(): + for disambiguation_request in per_question_disambiguation_request: + logging.info( + "Disambiguation Request Identified: %s", disambiguation_request + ) + + request = DismabiguationRequestsPayload.Body.DismabiguationRequest( + agent_question=disambiguation_request["agent_question"], + user_choices=disambiguation_request["user_choices"], + ) + request_payload.body.disambiguation_requests.append(request) + + return request_payload + def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: """Extract the sources from the answer.""" answer = messages[-1].content @@ -145,41 +179,35 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: except json.JSONDecodeError: logging.warning("Unable to read SQL query results: %s", sql_query_results) sql_query_results = {} - sub_question_results = {} - else: - # Only load sub-question results if we have a database result - sub_question_results = self.parse_message_content(messages[1].content) - logging.info("Sub-Question Results: %s", sub_question_results) try: - sub_questions = [ - sub_question - for sub_question_group in sub_question_results.get("sub_questions", []) - for sub_question in sub_question_group - ] + decomposed_user_messages = self.extract_decomposed_user_messages(messages) logging.info("SQL Query Results: %s", sql_query_results) payload = AnswerWithSourcesPayload( - answer=answer, sub_questions=sub_questions + answer=answer, decomposed_user_messages=decomposed_user_messages ) if not isinstance(sql_query_results, dict): logging.error(f"Expected dict, got {type(sql_query_results)}") return payload - if "results" not in sql_query_results: + if "database_results" not in sql_query_results: logging.error("No 'results' key in sql_query_results") return payload - for question, sql_query_result_list in sql_query_results["results"].items(): + for message, sql_query_result_list in sql_query_results[ + "database_results" + ].items(): if not sql_query_result_list: # Check if list is empty - logging.warning(f"No results for question: {question}") + logging.warning(f"No results for message: {message}") continue for sql_query_result in sql_query_result_list: if not isinstance(sql_query_result, dict): logging.error( - f"Expected dict for sql_query_result, got {type(sql_query_result)}" + "Expected dict for sql_query_result, got %s", + type(sql_query_result), ) continue @@ -208,16 +236,16 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: answer=f"{answer}\nError processing results: {str(e)}" ) - async def process_question( + async def process_user_message( self, - question_payload: QuestionPayload, + message_payload: UserMessagePayload, chat_history: list[InteractionPayload] = None, ) -> AsyncGenerator[InteractionPayload, None]: - """Process the complete question through the unified system. + """Process the complete message through the unified system. Args: ---- - task (str): The user question to process. + task (str): The user message to process. chat_history (list[str], optional): The chat history. Defaults to None. injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None. @@ -225,22 +253,22 @@ async def process_question( ------- dict: The response from the system. """ - logging.info("Processing question: %s", question_payload.body.question) + logging.info("Processing message: %s", message_payload.body.user_message) logging.info("Chat history: %s", chat_history) agent_input = { - "question": question_payload.body.question, + "message": message_payload.body.user_message, "chat_history": {}, - "injected_parameters": question_payload.body.injected_parameters, + "injected_parameters": message_payload.body.injected_parameters, } if chat_history is not None: # Update input for idx, chat in enumerate(chat_history): - if chat.root.payload_type == PayloadType.QUESTION: + if chat.root.payload_type == PayloadType.USER_MESSAGE: # For now only consider the user query chat_history_key = f"chat_{idx}" - agent_input[chat_history_key] = chat.root.body.question + agent_input[chat_history_key] = chat.root.body.user_message async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): logging.debug("Message: %s", message) @@ -248,7 +276,7 @@ async def process_question( payload = None if isinstance(message, TextMessage): - if message.source == "question_rewrite_agent": + if message.source == "user_message_rewrite_agent": payload = ProcessingUpdatePayload( message="Rewriting the query...", ) @@ -271,10 +299,10 @@ async def process_question( elif message.messages[-1].source == "parallel_query_solving_agent": # Load into disambiguation request payload = self.extract_disambiguation_request(message.messages) - elif message.messages[-1].source == "question_rewrite_agent": + elif message.messages[-1].source == "user_message_rewrite_agent": # Load into empty response payload = AnswerWithSourcesPayload( - answer="Apologies, I cannot answer that question as it is not relevant. Please try another question or rephrase your current question." + answer="Apologies, I cannot answer that message as it is not relevant. Please try another message or rephrase your current message." ) else: logging.error("Unexpected TaskResult: %s", message) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py index 5a26c7a..886310d 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py @@ -43,7 +43,7 @@ def get_tool(cls, sql_helper, tool_name: str): elif tool_name == "sql_get_entity_schemas_tool": return FunctionToolAlias( sql_helper.get_entity_schemas, - description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.", + description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user input and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.", ) elif tool_name == "sql_get_column_values_tool": return FunctionToolAlias( diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index e0b0888..c93db3e 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -18,6 +18,18 @@ from json import JSONDecodeError import re import os +from pydantic import BaseModel, Field + + +class FilteredParallelMessagesCollection(BaseModel): + database_results: dict[str, list] = Field(default_factory=dict) + disambiguation_requests: dict[str, list] = Field(default_factory=dict) + + def add_identifier(self, identifier): + if identifier not in self.database_results: + self.database_results[identifier] = [] + if identifier not in self.disambiguation_requests: + self.disambiguation_requests[identifier] = [] class ParallelQuerySolvingAgent(BaseChatAgent): @@ -84,12 +96,12 @@ async def on_messages_stream( injected_parameters = {} # Load the json of the last message to populate the final output object - question_rewrites = json.loads(last_response) + message_rewrites = json.loads(last_response) - logging.info(f"Query Rewrites: {question_rewrites}") + logging.info(f"Query Rewrites: {message_rewrites}") async def consume_inner_messages_from_agentic_flow( - agentic_flow, identifier, database_results + agentic_flow, identifier, filtered_parallel_messages ): """ Consume the inner messages and append them to the specified list. @@ -101,8 +113,7 @@ async def consume_inner_messages_from_agentic_flow( """ async for inner_message in agentic_flow: # Add message to results dictionary, tagged by the function name - if identifier not in database_results: - database_results[identifier] = [] + filtered_parallel_messages.add_identifier(identifier) logging.info(f"Checking Inner Message: {inner_message}") @@ -122,7 +133,9 @@ async def consume_inner_messages_from_agentic_flow( == "query_execution_with_limit" ): logging.info("Contains query results") - database_results[identifier].append( + filtered_parallel_messages.database_results[ + identifier + ].append( { "sql_query": parsed_message[ "sql_query" @@ -138,14 +151,17 @@ async def consume_inner_messages_from_agentic_flow( # Search for specific message types and add them to the final output object if isinstance(parsed_message, dict): + # Check if the message contains pre-run results if ("contains_pre_run_results" in parsed_message) and ( parsed_message["contains_pre_run_results"] is True ): logging.info("Contains pre-run results") for pre_run_sql_query, pre_run_result in parsed_message[ - "cached_questions_and_schemas" + "cached_messages_and_schemas" ].items(): - database_results[identifier].append( + filtered_parallel_messages.database_results[ + identifier + ].append( { "sql_query": pre_run_sql_query.replace( "\n", " " @@ -153,6 +169,17 @@ async def consume_inner_messages_from_agentic_flow( "sql_rows": pre_run_result["sql_rows"], } ) + # Check if disambiguation is required + elif ("disambiguation_requests" in parsed_message) and ( + parsed_message["disambiguation_requests"] + ): + logging.info("Contains disambiguation requests") + for disambiguation_request in parsed_message[ + "disambiguation_requests" + ]: + filtered_parallel_messages.disambiguation_requests[ + identifier + ].append(disambiguation_request) except Exception as e: logging.warning(f"Error processing message: {e}") @@ -160,11 +187,11 @@ async def consume_inner_messages_from_agentic_flow( yield inner_message inner_solving_generators = [] - database_results = {} + filtered_parallel_messages = FilteredParallelMessagesCollection() # Convert all_non_database_query to lowercase string and compare all_non_database_query = str( - question_rewrites.get("all_non_database_query", "false") + message_rewrites.get("all_non_database_query", "false") ).lower() if all_non_database_query == "true": @@ -177,12 +204,12 @@ async def consume_inner_messages_from_agentic_flow( return # Start processing sub-queries - for question_rewrite in question_rewrites["sub_questions"]: - logging.info(f"Processing sub-query: {question_rewrite}") + for message_rewrite in message_rewrites["decomposed_user_messages"]: + logging.info(f"Processing sub-query: {message_rewrite}") # Create an instance of the InnerAutoGenText2Sql class inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs) - identifier = ", ".join(question_rewrite) + identifier = ", ".join(message_rewrite) # Add database connection info to injected parameters query_params = injected_parameters.copy() if injected_parameters else {} @@ -196,12 +223,12 @@ async def consume_inner_messages_from_agentic_flow( # Launch tasks for each sub-query inner_solving_generators.append( consume_inner_messages_from_agentic_flow( - inner_autogen_text_2_sql.process_question( - question=question_rewrite, + inner_autogen_text_2_sql.process_user_message( + user_message=message_rewrite, injected_parameters=query_params, ), identifier, - database_results, + filtered_parallel_messages, ) ) @@ -218,17 +245,43 @@ async def consume_inner_messages_from_agentic_flow( yield inner_message # Log final results for debugging or auditing - logging.info(f"Database Results: {database_results}") + logging.info( + "Database Results: %s", filtered_parallel_messages.database_results + ) + logging.info( + "Disambiguation Requests: %s", + filtered_parallel_messages.disambiguation_requests, + ) - # Final response - yield Response( - chat_message=TextMessage( - content=json.dumps( - {"contains_results": True, "results": database_results} + if ( + max(map(len, filtered_parallel_messages.disambiguation_requests.values())) + > 0 + ): + # Final response + yield Response( + chat_message=TextMessage( + content=json.dumps( + { + "contains_disambiguation_requests": True, + "disambiguation_requests": filtered_parallel_messages.disambiguation_requests, + } + ), + source=self.name, ), - source=self.name, - ), - ) + ) + else: + # Final response + yield Response( + chat_message=TextMessage( + content=json.dumps( + { + "contains_database_results": True, + "database_results": filtered_parallel_messages.database_results, + } + ), + source=self.name, + ), + ) async def on_reset(self, cancellation_token: CancellationToken) -> None: pass diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index b7d0072..d17b886 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -17,7 +17,7 @@ class SqlQueryCacheAgent(BaseChatAgent): def __init__(self): super().__init__( "sql_query_cache_agent", - "An agent that fetches the queries from the cache based on the user question.", + "An agent that fetches the queries from the cache based on the user message.", ) self.agent = SqlQueryCacheAgentCustomAgent() @@ -40,19 +40,19 @@ async def on_messages( async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: - # Get the decomposed questions from the question_rewrite_agent + # Get the decomposed messages from the user_message_rewrite_agent try: request_details = json.loads(messages[0].content) injected_parameters = request_details["injected_parameters"] - user_questions = request_details["question"] - logging.info(f"Processing questions: {user_questions}") + user_messages = request_details["user_message"] + logging.info(f"Processing messages: {user_messages}") logging.info(f"Input Parameters: {injected_parameters}") except json.JSONDecodeError: - # If not JSON array, process as single question + # If not JSON array, process as single message raise ValueError("Could not load message") cached_results = await self.agent.process_message( - user_questions, injected_parameters + user_messages, injected_parameters ) yield Response( chat_message=TextMessage( diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py index c33219b..557d7a5 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py @@ -17,7 +17,7 @@ class SqlSchemaSelectionAgent(BaseChatAgent): def __init__(self, **kwargs): super().__init__( "sql_schema_selection_agent", - "An agent that fetches the schemas from the cache based on the user question.", + "An agent that fetches the schemas from the cache based on the user input.", ) self.agent = SqlSchemaSelectionAgentCustomAgent(**kwargs) @@ -43,19 +43,19 @@ async def on_messages_stream( # Try to parse as JSON first try: request_details = json.loads(messages[0].content) - user_questions = request_details["question"] + messages = request_details["question"] except (json.JSONDecodeError, KeyError): # If not JSON or missing question key, use content directly - user_questions = messages[0].content + messages = messages[0].content - if isinstance(user_questions, str): - user_questions = [user_questions] - elif not isinstance(user_questions, list): - user_questions = [str(user_questions)] + if isinstance(messages, str): + messages = [messages] + elif not isinstance(messages, list): + messages = [str(messages)] - logging.info(f"Processing questions: {user_questions}") + logging.info(f"Processing questions: {messages}") - final_results = await self.agent.process_message(user_questions) + final_results = await self.agent.process_message(messages) yield Response( chat_message=TextMessage( diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py b/text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py index 4f2f158..edb8980 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import re from typing import Optional, List, Dict, Any @@ -14,7 +16,7 @@ def extract_sql_queries_from_results(results: Dict[str, Any]) -> List[str]: """ queries = [] - if results.get("contains_results") and results.get("results"): + if results.get("contains_database_results") and results.get("results"): for question_results in results["results"].values(): for result in question_results: if isinstance(result, dict) and "sql_query" in result: diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index 5b48ab6..a83000d 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -124,7 +124,11 @@ def get_all_agents(self): @property def termination_condition(self): """Define the termination condition for the chat.""" - termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(10) + termination = ( + TextMentionTermination("TERMINATE") + | MaxMessageTermination(10) + | TextMentionTermination("disambiguation_request") + ) return termination def unified_selector(self, messages): @@ -169,31 +173,30 @@ def agentic_flow(self): ) return flow - def process_question( + def process_user_message( self, - question: str, + user_message: str, injected_parameters: dict = None, ): """Process the complete question through the unified system. Args: ---- - task (str): The user question to process. + task (str): The user input to process. injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None. Returns: ------- dict: The response from the system. """ - logging.info("Processing question: %s", question) + logging.info("Processing question: %s", user_message) # Update environment with injected parameters self._update_environment(injected_parameters) try: agent_input = { - "question": question, - "chat_history": {}, + "user_message": user_message, "injected_parameters": injected_parameters, } diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 70ef72a..46feb4d 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -187,7 +187,8 @@ async def query_execution_with_limit( "type": "query_execution_with_limit", "sql_query": sql_query, "sql_rows": result, - } + }, + default=str, ) else: return json.dumps( @@ -195,7 +196,8 @@ async def query_execution_with_limit( "type": "errored_query_execution_with_limit", "sql_query": sql_query, "errors": validation_result, - } + }, + default=str, ) async def query_validation( diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py index 958d1a4..b58f9b5 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py @@ -9,7 +9,7 @@ def __init__(self): self.sql_connector = ConnectorFactory.get_database_connector() async def process_message( - self, user_questions: list[str], injected_parameters: dict + self, messages: list[str], injected_parameters: dict ) -> dict: # Initialize results dictionary cached_results = { @@ -18,11 +18,11 @@ async def process_message( } # Process each question sequentially - for question in user_questions: + for message in messages: # Fetch the queries from the cache based on the question - logging.info(f"Fetching queries from cache for question: {question}") + logging.info(f"Fetching queries from cache for question: {message}") cached_query = await self.sql_connector.fetch_queries_from_cache( - question, injected_parameters=injected_parameters + message, injected_parameters=injected_parameters ) # If any question has pre-run results, set the flag diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py index a2e0293..adfba3b 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py @@ -22,15 +22,15 @@ def __init__(self, **kwargs): self.system_prompt = Template(system_prompt).render(kwargs) - async def process_message(self, user_questions: list[str]) -> dict: - logging.info(f"User questions: {user_questions}") + async def process_message(self, messages: list[str]) -> dict: + logging.info(f"user inputs: {messages}") entity_tasks = [] - for user_question in user_questions: + for message in messages: messages = [ {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": user_question}, + {"role": "user", "content": message}, ] entity_tasks.append( self.open_ai_connector.run_completion_request( @@ -47,7 +47,7 @@ async def process_message(self, user_questions: list[str]) -> dict: logging.info(f"Entity result: {entity_result}") for entity_group in entity_result.entities: - logging.info(f"Searching for schemas for entity group: {entity_group}") + logging.info("Searching for schemas for entity group: %s", entity_group) entity_search_tasks.append( self.sql_connector.get_entity_schemas( " ".join(entity_group), as_json=False @@ -56,7 +56,7 @@ async def process_message(self, user_questions: list[str]) -> dict: for filter_condition in entity_result.filter_conditions: logging.info( - f"Searching for column values for filter: {filter_condition}" + "Searching for column values for filter: %s", filter_condition ) column_search_tasks.append( self.sql_connector.get_column_values( diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index ad97154..1e7e621 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from pydantic import BaseModel, RootModel, Field, model_validator +from pydantic import BaseModel, RootModel, Field, model_validator, ConfigDict from enum import StrEnum from typing import Literal @@ -8,18 +8,6 @@ from uuid import uuid4 -class PayloadBase(BaseModel): - prompt_tokens: int | None = None - completion_tokens: int | None = None - message_id: str = Field(..., default_factory=lambda: str(uuid4())) - timestamp: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - description="Timestamp in UTC", - ) - payload_type: str - payload_source: str - - class PayloadSource(StrEnum): USER = "user" AGENT = "agent" @@ -27,25 +15,46 @@ class PayloadSource(StrEnum): class PayloadType(StrEnum): ANSWER_WITH_SOURCES = "answer_with_sources" - DISAMBIGUATION_REQUEST = "disambiguation_request" + DISAMBIGUATION_REQUESTS = "disambiguation_requests" PROCESSING_UPDATE = "processing_update" - QUESTION = "question" + USER_MESSAGE = "user_message" -class DismabiguationRequestPayload(PayloadBase): - class Body(BaseModel): - class DismabiguationRequest(BaseModel): - question: str - matching_columns: list[str] - matching_filter_values: list[str] - other_user_choices: list[str] +class InteractionPayloadBase(BaseModel): + model_config = ConfigDict(populate_by_name=True, extra="ignore") - disambiguation_requests: list[DismabiguationRequest] - payload_type: Literal[ - PayloadType.DISAMBIGUATION_REQUEST - ] = PayloadType.DISAMBIGUATION_REQUEST - payload_source: Literal[PayloadSource.AGENT] = PayloadSource.AGENT +class PayloadBase(InteractionPayloadBase): + message_id: str = Field( + ..., default_factory=lambda: str(uuid4()), alias="messageId" + ) + timestamp: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + description="Timestamp in UTC", + ) + payload_type: PayloadType = Field(..., alias="payloadType") + payload_source: PayloadSource = Field(..., alias="payloadSource") + + +class DismabiguationRequestsPayload(InteractionPayloadBase): + class Body(InteractionPayloadBase): + class DismabiguationRequest(InteractionPayloadBase): + agent_question: str | None = Field(..., alias="agentQuestion") + user_choices: list[str] | None = Field(default=None, alias="userChoices") + + disambiguation_requests: list[DismabiguationRequest] | None = Field( + default_factory=list, alias="disambiguationRequests" + ) + decomposed_user_messages: list[list[str]] = Field( + default_factory=list, alias="decomposedUserMessages" + ) + + payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field( + PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType" + ) + payload_source: Literal[PayloadSource.AGENT] = Field( + default=PayloadSource.AGENT, alias="payloadSource" + ) body: Body | None = Field(default=None) def __init__(self, **kwargs): @@ -54,20 +63,24 @@ def __init__(self, **kwargs): self.body = self.Body(**kwargs) -class AnswerWithSourcesPayload(PayloadBase): - class Body(BaseModel): - class Source(BaseModel): - sql_query: str - sql_rows: list[dict] +class AnswerWithSourcesPayload(InteractionPayloadBase): + class Body(InteractionPayloadBase): + class Source(InteractionPayloadBase): + sql_query: str = Field(alias="sqlQuery") + sql_rows: list[dict] = Field(default_factory=list, alias="sqlRows") answer: str - sub_questions: list[str] = Field(default_factory=list) + decomposed_user_messages: list[list[str]] = Field( + default_factory=list, alias="decomposedUserMessages" + ) sources: list[Source] = Field(default_factory=list) - payload_type: Literal[ - PayloadType.ANSWER_WITH_SOURCES - ] = PayloadType.ANSWER_WITH_SOURCES - payload_source: Literal[PayloadSource.AGENT] = PayloadSource.AGENT + payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field( + PayloadType.ANSWER_WITH_SOURCES, alias="payloadType" + ) + payload_source: Literal[PayloadSource.AGENT] = Field( + PayloadSource.AGENT, alias="payloadSource" + ) body: Body | None = Field(default=None) def __init__(self, **kwargs): @@ -76,13 +89,17 @@ def __init__(self, **kwargs): self.body = self.Body(**kwargs) -class ProcessingUpdatePayload(PayloadBase): - class Body(BaseModel): +class ProcessingUpdatePayload(InteractionPayloadBase): + class Body(InteractionPayloadBase): title: str | None = "Processing..." message: str | None = "Processing..." - payload_type: Literal[PayloadType.PROCESSING_UPDATE] = PayloadType.PROCESSING_UPDATE - payload_source: Literal[PayloadSource.AGENT] = PayloadSource.AGENT + payload_type: Literal[PayloadType.PROCESSING_UPDATE] = Field( + PayloadType.PROCESSING_UPDATE, alias="payloadType" + ) + payload_source: Literal[PayloadSource.AGENT] = Field( + PayloadSource.AGENT, alias="payloadSource" + ) body: Body | None = Field(default=None) def __init__(self, **kwargs): @@ -91,10 +108,12 @@ def __init__(self, **kwargs): self.body = self.Body(**kwargs) -class QuestionPayload(PayloadBase): - class Body(BaseModel): - question: str - injected_parameters: dict = Field(default_factory=dict) +class UserMessagePayload(InteractionPayloadBase): + class Body(InteractionPayloadBase): + user_message: str = Field(..., alias="userMessage") + injected_parameters: dict = Field( + default_factory=dict, alias="injectedParameters" + ) @model_validator(mode="before") def add_defaults(cls, values): @@ -108,8 +127,12 @@ def add_defaults(cls, values): values["injected_parameters"] = {**defaults, **injected} return values - payload_type: Literal[PayloadType.QUESTION] = PayloadType.QUESTION - payload_source: Literal[PayloadSource.USER] = PayloadSource.USER + payload_type: Literal[PayloadType.USER_MESSAGE] = Field( + PayloadType.USER_MESSAGE, alias="payloadType" + ) + payload_source: Literal[PayloadSource.USER] = Field( + PayloadSource.USER, alias="payloadSource" + ) body: Body | None = Field(default=None) def __init__(self, **kwargs): @@ -119,6 +142,6 @@ def __init__(self, **kwargs): class InteractionPayload(RootModel): - root: QuestionPayload | ProcessingUpdatePayload | DismabiguationRequestPayload | AnswerWithSourcesPayload = Field( + root: UserMessagePayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field( ..., discriminator="payload_type" ) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml index e31ed70..df8c0c0 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml @@ -170,12 +170,16 @@ system_message: If disambiguation needed: { - \"disambiguation\": [{ - \"question\": \"\", - \"matching_columns\": [\"\", \"\"], - \"matching_filter_values\": [\"\", \"\"], - \"other_user_choices\": [\"\", \"\"] - }] + \"disambiguation_requests\": [ + { + \"agent_question\": \"\", + \"user_choices\": [\"\", \"\"] + }, + { + \"agent_question\": \"\", + \"user_choices\": [\"\", \"\"] + } + ] } TERMINATE diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/question_rewrite_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_message_rewrite_agent.yaml similarity index 82% rename from text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/question_rewrite_agent.yaml rename to text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_message_rewrite_agent.yaml index 134e0ec..a779d52 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/question_rewrite_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_message_rewrite_agent.yaml @@ -1,8 +1,8 @@ model: "4o-mini" -description: "An agent that preprocesses user questions by decomposing complex queries into simpler sub-queries that can be processed independently and then combined." +description: "An agent that preprocesses user inputs by decomposing complex queries into simpler sub-messages that can be processed independently and then combined." system_message: | - You are a helpful AI Assistant specializing in breaking down complex questions into simpler sub-queries that can be processed independently and then combined for the final answer. You should identify when a question can be solved through simpler sub-queries and provide clear instructions for combining their results. + You are a helpful AI Assistant specializing in breaking down complex questions into simpler sub-messages that can be processed independently and then combined for the final answer. You should identify when a question can be solved through simpler sub-messages and provide clear instructions for combining their results. @@ -36,7 +36,7 @@ system_message: | 1. Question Filtering and Classification - Use the provided list of topics to filter out malicious or unrelated queries. - Ensure the question is relevant to the system's use case. - - If the question cannot be filtered, output an empty sub-query list in the JSON format. Followed by TERMINATE. + - If the question cannot be filtered, output an empty sub-message list in the JSON format. Followed by TERMINATE. - For non-database questions like greetings (e.g., "Hello", "What can you do?", "How are you?"), set "all_non_database_query" to true. - For questions about data (e.g., queries about records, counts, values, comparisons, or any questions that would require database access), set "all_non_database_query" to false. @@ -51,28 +51,27 @@ system_message: | - Determine if breaking down would simplify processing 4. Break Down Complex Queries: - - Create independent sub-queries that can be processed separately. - - Each sub-query should be a simple, focused task. - - Group dependent sub-queries together for sequential processing. - - Ensure each sub-query is simple and focused + - Create independent sub-messages that can be processed separately. + - Each sub-message should be a simple, focused task. + - Group dependent sub-messages together for sequential processing. - Include clear combination instructions - - Preserve all necessary context in each sub-query + - Preserve all necessary context in each sub-message 5. Handle Date References: - Resolve relative dates using {{ current_datetime }} - Maintain consistent YYYY-MM-DD format - - Include date context in each sub-query + - Include date context in each sub-message 6. Maintain Query Context: - - Each sub-query should be self-contained + - Each sub-message should be self-contained - Include all necessary filtering conditions - Preserve business context 1. Always consider if a complex query can be broken down - 2. Make sub-queries as simple as possible + 2. Make sub-messages as simple as possible 3. Include clear instructions for combining results - 4. Preserve all necessary context in each sub-query + 4. Preserve all necessary context in each sub-message 5. Resolve any relative dates before decomposition @@ -90,11 +89,11 @@ system_message: | - Return a JSON object with sub-queries and combination instructions: + Return a JSON object with sub-messages and combination instructions: { - "sub_questions": [ - [""], - [""], + "decomposed_user_messages": [ + [""], + [""], ... ], "combination_logic": "", @@ -109,7 +108,7 @@ system_message: | Input: "Which product categories have shown consistent growth quarter over quarter in 2008, and what were their top selling items?" Output: { - "sub_questions": [ + "decomposed_user_messages": [ ["Calculate quarterly sales totals by product category for 2008", "For these categories, find their top selling products in 2008"] ], "combination_logic": "First identify growing categories from quarterly analysis, then find their best-selling products", @@ -121,7 +120,7 @@ system_message: | Input: "How many orders did we have in 2008?" Output: { - "sub_questions": [ + "decomposed_user_messages": [ ["How many orders did we have in 2008?"] ], "combination_logic": "Direct count query, no combination needed", @@ -133,12 +132,12 @@ system_message: | Input: "Compare the sales performance of our top 5 products in Europe versus North America, including their market share in each region" Output: { - "sub_questions": [ + "decomposed_user_messages": [ ["Get total sales by product in European countries"], ["Get total sales by product in North American countries"], ["Calculate total market size for each region", "Find top 5 products by sales in each region"], ], - "combination_logic": "First identify top products in each region, then calculate and compare their market shares. Questions that depend on the result of each sub-query are combined.", + "combination_logic": "First identify top products in each region, then calculate and compare their market shares. Questions that depend on the result of each sub-message are combined.", "query_type": "complex", "all_non_database_query": "false" } @@ -147,7 +146,7 @@ system_message: | Input: "Hello, what can you help me with?" Output: { - "sub_questions": [ + "decomposed_user_messages": [ ["What are your capabilities?"] ], "combination_logic": "Simple greeting and capability question",