From 70b73c48c50039d0eb05d77c7402c4ae2cc6e03d Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 9 Jan 2025 11:00:01 +0000 Subject: [PATCH 01/10] Update interaction payload --- .../src/autogen_text_2_sql/autogen_text_2_sql.py | 9 +++++---- .../payloads/interaction_payloads.py | 13 ++++++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index f65f54c..5983ca4 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -21,7 +21,7 @@ from text_2_sql_core.payloads.interaction_payloads import ( QuestionPayload, AnswerWithSourcesPayload, - DismabiguationRequestPayload, + DismabiguationRequestsPayload, ProcessingUpdatePayload, InteractionPayload, PayloadType, @@ -104,10 +104,10 @@ def agentic_flow(self): def extract_disambiguation_request( self, messages: list - ) -> DismabiguationRequestPayload: + ) -> DismabiguationRequestsPayload: """Extract the disambiguation request from the answer.""" disambiguation_request = messages[-1].content - return DismabiguationRequestPayload( + return DismabiguationRequestsPayload( disambiguation_request=disambiguation_request, ) @@ -179,7 +179,8 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: for sql_query_result in sql_query_result_list: if not isinstance(sql_query_result, dict): logging.error( - f"Expected dict for sql_query_result, got {type(sql_query_result)}" + "Expected dict for sql_query_result, got %s", + type(sql_query_result), ) continue diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index ad97154..cd70368 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -32,13 +32,16 @@ class PayloadType(StrEnum): QUESTION = "question" -class DismabiguationRequestPayload(PayloadBase): +class ColumnFilterPair(BaseModel): + column: str + filter_value: str | None = Field(default=None) + + +class DismabiguationRequestsPayload(PayloadBase): class Body(BaseModel): class DismabiguationRequest(BaseModel): question: str - matching_columns: list[str] - matching_filter_values: list[str] - other_user_choices: list[str] + choices: list[ColumnFilterPair] | None = Field(default=None) disambiguation_requests: list[DismabiguationRequest] @@ -119,6 +122,6 @@ def __init__(self, **kwargs): class InteractionPayload(RootModel): - root: QuestionPayload | ProcessingUpdatePayload | DismabiguationRequestPayload | AnswerWithSourcesPayload = Field( + root: QuestionPayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field( ..., discriminator="payload_type" ) From 7211486643b45eccafab9b3ec98f05c9becf6b8c Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 9 Jan 2025 16:18:05 +0000 Subject: [PATCH 02/10] Update payloads --- .../src/autogen_text_2_sql/__init__.py | 4 +-- .../autogen_text_2_sql/autogen_text_2_sql.py | 6 ++-- .../payloads/interaction_payloads.py | 35 ++++++++++++++++--- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py index cda320f..942e106 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql -from text_2_sql_core.payloads.interaction_payloads import QuestionPayload +from text_2_sql_core.payloads.interaction_payloads import UserInputPayload -__all__ = ["AutoGenText2Sql", "QuestionPayload"] +__all__ = ["AutoGenText2Sql", "UserInputPayload"] diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 5983ca4..bf804cc 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -19,7 +19,7 @@ import re from text_2_sql_core.payloads.interaction_payloads import ( - QuestionPayload, + UserInputPayload, AnswerWithSourcesPayload, DismabiguationRequestsPayload, ProcessingUpdatePayload, @@ -211,7 +211,7 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: async def process_question( self, - question_payload: QuestionPayload, + question_payload: UserInputPayload, chat_history: list[InteractionPayload] = None, ) -> AsyncGenerator[InteractionPayload, None]: """Process the complete question through the unified system. @@ -238,7 +238,7 @@ async def process_question( if chat_history is not None: # Update input for idx, chat in enumerate(chat_history): - if chat.root.payload_type == PayloadType.QUESTION: + if chat.root.payload_type == PayloadType.USER_INPUT: # For now only consider the user query chat_history_key = f"chat_{idx}" agent_input[chat_history_key] = chat.root.body.question diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index cd70368..3ff36b4 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -29,10 +29,11 @@ class PayloadType(StrEnum): ANSWER_WITH_SOURCES = "answer_with_sources" DISAMBIGUATION_REQUEST = "disambiguation_request" PROCESSING_UPDATE = "processing_update" - QUESTION = "question" + USER_INPUT = "user_input" class ColumnFilterPair(BaseModel): + fqn: str column: str filter_value: str | None = Field(default=None) @@ -57,6 +58,32 @@ def __init__(self, **kwargs): self.body = self.Body(**kwargs) +request = DismabiguationRequestsPayload( + disambiguation_requests=[ + { + "question": "Which of the following do you mean?", + "choices": [ + {"fqn": "", "column": "product_name", "filter_value": "Road Bike"}, + { + "fqn": "", + "column": "product_name", + "filter_value": "Mountain Bike", + }, + ], + }, + { + "question": "Do you mean total sales by volume or number of customers?", + "choices": [ + {"fqn": "", "column": "sales_volume", "filter_value": None}, + {"fqn": "", "column": "customer_count", "filter_value": None}, + ], + }, + ] +) + +print(request.model_dump()) + + class AnswerWithSourcesPayload(PayloadBase): class Body(BaseModel): class Source(BaseModel): @@ -94,7 +121,7 @@ def __init__(self, **kwargs): self.body = self.Body(**kwargs) -class QuestionPayload(PayloadBase): +class UserInputPayload(PayloadBase): class Body(BaseModel): question: str injected_parameters: dict = Field(default_factory=dict) @@ -111,7 +138,7 @@ def add_defaults(cls, values): values["injected_parameters"] = {**defaults, **injected} return values - payload_type: Literal[PayloadType.QUESTION] = PayloadType.QUESTION + payload_type: Literal[PayloadType.USER_INPUT] = PayloadType.USER_INPUT payload_source: Literal[PayloadSource.USER] = PayloadSource.USER body: Body | None = Field(default=None) @@ -122,6 +149,6 @@ def __init__(self, **kwargs): class InteractionPayload(RootModel): - root: QuestionPayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field( + root: UserInputPayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field( ..., discriminator="payload_type" ) From 222c13b715e61e3f41fa0c5bfd41893598325dc7 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 9 Jan 2025 17:00:26 +0000 Subject: [PATCH 03/10] Update user input --- .../autogen_text_2_sql/autogen_text_2_sql.py | 60 ++++++++++--------- .../creators/llm_agent_creator.py | 2 +- .../parallel_query_solving_agent.py | 18 +++--- .../custom_agents/sql_query_cache_agent.py | 12 ++-- .../sql_schema_selection_agent.py | 18 +++--- .../inner_autogen_text_2_sql.py | 2 +- .../custom_agents/sql_query_cache_agent.py | 4 +- .../sql_schema_selection_agent.py | 12 ++-- ...ent.yaml => user_input_rewrite_agent.yaml} | 2 +- 9 files changed, 67 insertions(+), 63 deletions(-) rename text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/{question_rewrite_agent.yaml => user_input_rewrite_agent.yaml} (97%) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index bf804cc..9b3abb4 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -40,8 +40,8 @@ def get_all_agents(self): # Get current datetime for the Query Rewrite Agent current_datetime = datetime.now() - self.question_rewrite_agent = LLMAgentCreator.create( - "question_rewrite_agent", current_datetime=current_datetime + self.user_input_rewrite_agent = LLMAgentCreator.create( + "user_input_rewrite_agent", current_datetime=current_datetime ) self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs) @@ -49,7 +49,7 @@ def get_all_agents(self): self.answer_agent = LLMAgentCreator.create("answer_agent") agents = [ - self.question_rewrite_agent, + self.user_input_rewrite_agent, self.parallel_query_solving_agent, self.answer_agent, ] @@ -73,11 +73,11 @@ def unified_selector(self, messages): current_agent = messages[-1].source if messages else "user" decision = None - # If this is the first message start with question_rewrite_agent + # If this is the first message start with user_input_rewrite_agent if current_agent == "user": - decision = "question_rewrite_agent" + decision = "user_input_rewrite_agent" # Handle transition after query rewriting - elif current_agent == "question_rewrite_agent": + elif current_agent == "user_input_rewrite_agent": decision = "parallel_query_solving_agent" # Handle transition after parallel query solving elif current_agent == "parallel_query_solving_agent": @@ -145,22 +145,24 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: except json.JSONDecodeError: logging.warning("Unable to read SQL query results: %s", sql_query_results) sql_query_results = {} - sub_question_results = {} + sub_user_input_results = {} else: - # Only load sub-question results if we have a database result - sub_question_results = self.parse_message_content(messages[1].content) - logging.info("Sub-Question Results: %s", sub_question_results) + # Only load sub-user_input results if we have a database result + sub_user_input_results = self.parse_message_content(messages[1].content) + logging.info("Sub-user_input Results: %s", sub_user_input_results) try: - sub_questions = [ - sub_question - for sub_question_group in sub_question_results.get("sub_questions", []) - for sub_question in sub_question_group + sub_user_inputs = [ + sub_user_input + for sub_user_input_group in sub_user_input_results.get( + "sub_user_inputs", [] + ) + for sub_user_input in sub_user_input_group ] logging.info("SQL Query Results: %s", sql_query_results) payload = AnswerWithSourcesPayload( - answer=answer, sub_questions=sub_questions + answer=answer, sub_user_inputs=sub_user_inputs ) if not isinstance(sql_query_results, dict): @@ -171,9 +173,11 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: logging.error("No 'results' key in sql_query_results") return payload - for question, sql_query_result_list in sql_query_results["results"].items(): + for user_input, sql_query_result_list in sql_query_results[ + "results" + ].items(): if not sql_query_result_list: # Check if list is empty - logging.warning(f"No results for question: {question}") + logging.warning(f"No results for user_input: {user_input}") continue for sql_query_result in sql_query_result_list: @@ -209,16 +213,16 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: answer=f"{answer}\nError processing results: {str(e)}" ) - async def process_question( + async def process_user_input( self, - question_payload: UserInputPayload, + user_input_payload: UserInputPayload, chat_history: list[InteractionPayload] = None, ) -> AsyncGenerator[InteractionPayload, None]: - """Process the complete question through the unified system. + """Process the complete user_input through the unified system. Args: ---- - task (str): The user question to process. + task (str): The user user_input to process. chat_history (list[str], optional): The chat history. Defaults to None. injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None. @@ -226,13 +230,13 @@ async def process_question( ------- dict: The response from the system. """ - logging.info("Processing question: %s", question_payload.body.question) + logging.info("Processing user_input: %s", user_input_payload.body.user_input) logging.info("Chat history: %s", chat_history) agent_input = { - "question": question_payload.body.question, + "user_input": user_input_payload.body.user_input, "chat_history": {}, - "injected_parameters": question_payload.body.injected_parameters, + "injected_parameters": user_input_payload.body.injected_parameters, } if chat_history is not None: @@ -241,7 +245,7 @@ async def process_question( if chat.root.payload_type == PayloadType.USER_INPUT: # For now only consider the user query chat_history_key = f"chat_{idx}" - agent_input[chat_history_key] = chat.root.body.question + agent_input[chat_history_key] = chat.root.body.user_input async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): logging.debug("Message: %s", message) @@ -249,7 +253,7 @@ async def process_question( payload = None if isinstance(message, TextMessage): - if message.source == "question_rewrite_agent": + if message.source == "user_input_rewrite_agent": payload = ProcessingUpdatePayload( message="Rewriting the query...", ) @@ -272,10 +276,10 @@ async def process_question( elif message.messages[-1].source == "parallel_query_solving_agent": # Load into disambiguation request payload = self.extract_disambiguation_request(message.messages) - elif message.messages[-1].source == "question_rewrite_agent": + elif message.messages[-1].source == "user_input_rewrite_agent": # Load into empty response payload = AnswerWithSourcesPayload( - answer="Apologies, I cannot answer that question as it is not relevant. Please try another question or rephrase your current question." + answer="Apologies, I cannot answer that user_input as it is not relevant. Please try another user_input or rephrase your current user_input." ) else: logging.error("Unexpected TaskResult: %s", message) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py index 5a26c7a..886310d 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py @@ -43,7 +43,7 @@ def get_tool(cls, sql_helper, tool_name: str): elif tool_name == "sql_get_entity_schemas_tool": return FunctionToolAlias( sql_helper.get_entity_schemas, - description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.", + description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user input and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.", ) elif tool_name == "sql_get_column_values_tool": return FunctionToolAlias( diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index e0b0888..fae9a73 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -84,9 +84,9 @@ async def on_messages_stream( injected_parameters = {} # Load the json of the last message to populate the final output object - question_rewrites = json.loads(last_response) + user_input_rewrites = json.loads(last_response) - logging.info(f"Query Rewrites: {question_rewrites}") + logging.info(f"Query Rewrites: {user_input_rewrites}") async def consume_inner_messages_from_agentic_flow( agentic_flow, identifier, database_results @@ -143,7 +143,7 @@ async def consume_inner_messages_from_agentic_flow( ): logging.info("Contains pre-run results") for pre_run_sql_query, pre_run_result in parsed_message[ - "cached_questions_and_schemas" + "cached_user_inputs_and_schemas" ].items(): database_results[identifier].append( { @@ -164,7 +164,7 @@ async def consume_inner_messages_from_agentic_flow( # Convert all_non_database_query to lowercase string and compare all_non_database_query = str( - question_rewrites.get("all_non_database_query", "false") + user_input_rewrites.get("all_non_database_query", "false") ).lower() if all_non_database_query == "true": @@ -177,12 +177,12 @@ async def consume_inner_messages_from_agentic_flow( return # Start processing sub-queries - for question_rewrite in question_rewrites["sub_questions"]: - logging.info(f"Processing sub-query: {question_rewrite}") + for user_input_rewrite in user_input_rewrites["sub_user_inputs"]: + logging.info(f"Processing sub-query: {user_input_rewrite}") # Create an instance of the InnerAutoGenText2Sql class inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs) - identifier = ", ".join(question_rewrite) + identifier = ", ".join(user_input_rewrite) # Add database connection info to injected parameters query_params = injected_parameters.copy() if injected_parameters else {} @@ -196,8 +196,8 @@ async def consume_inner_messages_from_agentic_flow( # Launch tasks for each sub-query inner_solving_generators.append( consume_inner_messages_from_agentic_flow( - inner_autogen_text_2_sql.process_question( - question=question_rewrite, + inner_autogen_text_2_sql.process_user_input( + user_input=user_input_rewrite, injected_parameters=query_params, ), identifier, diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index b7d0072..9ee16bf 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -17,7 +17,7 @@ class SqlQueryCacheAgent(BaseChatAgent): def __init__(self): super().__init__( "sql_query_cache_agent", - "An agent that fetches the queries from the cache based on the user question.", + "An agent that fetches the queries from the cache based on the user user_input.", ) self.agent = SqlQueryCacheAgentCustomAgent() @@ -40,19 +40,19 @@ async def on_messages( async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: - # Get the decomposed questions from the question_rewrite_agent + # Get the decomposed user_inputs from the user_input_rewrite_agent try: request_details = json.loads(messages[0].content) injected_parameters = request_details["injected_parameters"] - user_questions = request_details["question"] - logging.info(f"Processing questions: {user_questions}") + user_user_inputs = request_details["user_input"] + logging.info(f"Processing user_inputs: {user_user_inputs}") logging.info(f"Input Parameters: {injected_parameters}") except json.JSONDecodeError: - # If not JSON array, process as single question + # If not JSON array, process as single user_input raise ValueError("Could not load message") cached_results = await self.agent.process_message( - user_questions, injected_parameters + user_user_inputs, injected_parameters ) yield Response( chat_message=TextMessage( diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py index c33219b..1c5b288 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py @@ -17,7 +17,7 @@ class SqlSchemaSelectionAgent(BaseChatAgent): def __init__(self, **kwargs): super().__init__( "sql_schema_selection_agent", - "An agent that fetches the schemas from the cache based on the user question.", + "An agent that fetches the schemas from the cache based on the user input.", ) self.agent = SqlSchemaSelectionAgentCustomAgent(**kwargs) @@ -43,19 +43,19 @@ async def on_messages_stream( # Try to parse as JSON first try: request_details = json.loads(messages[0].content) - user_questions = request_details["question"] + user_inputs = request_details["question"] except (json.JSONDecodeError, KeyError): # If not JSON or missing question key, use content directly - user_questions = messages[0].content + user_inputs = messages[0].content - if isinstance(user_questions, str): - user_questions = [user_questions] - elif not isinstance(user_questions, list): - user_questions = [str(user_questions)] + if isinstance(user_inputs, str): + user_inputs = [user_inputs] + elif not isinstance(user_inputs, list): + user_inputs = [str(user_inputs)] - logging.info(f"Processing questions: {user_questions}") + logging.info(f"Processing questions: {user_inputs}") - final_results = await self.agent.process_message(user_questions) + final_results = await self.agent.process_message(user_inputs) yield Response( chat_message=TextMessage( diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index 5b48ab6..d73e482 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -178,7 +178,7 @@ def process_question( Args: ---- - task (str): The user question to process. + task (str): The user input to process. injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None. Returns: diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py index 958d1a4..c857cdc 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py @@ -9,7 +9,7 @@ def __init__(self): self.sql_connector = ConnectorFactory.get_database_connector() async def process_message( - self, user_questions: list[str], injected_parameters: dict + self, user_inputs: list[str], injected_parameters: dict ) -> dict: # Initialize results dictionary cached_results = { @@ -18,7 +18,7 @@ async def process_message( } # Process each question sequentially - for question in user_questions: + for question in user_inputs: # Fetch the queries from the cache based on the question logging.info(f"Fetching queries from cache for question: {question}") cached_query = await self.sql_connector.fetch_queries_from_cache( diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py index a2e0293..7ee28ff 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py @@ -22,15 +22,15 @@ def __init__(self, **kwargs): self.system_prompt = Template(system_prompt).render(kwargs) - async def process_message(self, user_questions: list[str]) -> dict: - logging.info(f"User questions: {user_questions}") + async def process_message(self, user_inputs: list[str]) -> dict: + logging.info(f"user inputs: {user_inputs}") entity_tasks = [] - for user_question in user_questions: + for user_input in user_inputs: messages = [ {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": user_question}, + {"role": "user", "content": user_input}, ] entity_tasks.append( self.open_ai_connector.run_completion_request( @@ -47,7 +47,7 @@ async def process_message(self, user_questions: list[str]) -> dict: logging.info(f"Entity result: {entity_result}") for entity_group in entity_result.entities: - logging.info(f"Searching for schemas for entity group: {entity_group}") + logging.info("Searching for schemas for entity group: %s", entity_group) entity_search_tasks.append( self.sql_connector.get_entity_schemas( " ".join(entity_group), as_json=False @@ -56,7 +56,7 @@ async def process_message(self, user_questions: list[str]) -> dict: for filter_condition in entity_result.filter_conditions: logging.info( - f"Searching for column values for filter: {filter_condition}" + "Searching for column values for filter: %s", filter_condition ) column_search_tasks.append( self.sql_connector.get_column_values( diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/question_rewrite_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_input_rewrite_agent.yaml similarity index 97% rename from text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/question_rewrite_agent.yaml rename to text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_input_rewrite_agent.yaml index 134e0ec..23397e0 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/question_rewrite_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_input_rewrite_agent.yaml @@ -1,5 +1,5 @@ model: "4o-mini" -description: "An agent that preprocesses user questions by decomposing complex queries into simpler sub-queries that can be processed independently and then combined." +description: "An agent that preprocesses user inputs by decomposing complex queries into simpler sub-queries that can be processed independently and then combined." system_message: | You are a helpful AI Assistant specializing in breaking down complex questions into simpler sub-queries that can be processed independently and then combined for the final answer. You should identify when a question can be solved through simpler sub-queries and provide clear instructions for combining their results. From 00b332d88d9c9364600eac9ff2eda03cc4f4d591 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 9 Jan 2025 17:07:01 +0000 Subject: [PATCH 04/10] Use value and display --- .../payloads/interaction_payloads.py | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index 3ff36b4..09f11af 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -32,17 +32,16 @@ class PayloadType(StrEnum): USER_INPUT = "user_input" -class ColumnFilterPair(BaseModel): - fqn: str - column: str - filter_value: str | None = Field(default=None) +class Choice(BaseModel): + value: str + display: str class DismabiguationRequestsPayload(PayloadBase): class Body(BaseModel): class DismabiguationRequest(BaseModel): question: str - choices: list[ColumnFilterPair] | None = Field(default=None) + choices: list[Choice] | None = Field(default=None) disambiguation_requests: list[DismabiguationRequest] @@ -63,25 +62,33 @@ def __init__(self, **kwargs): { "question": "Which of the following do you mean?", "choices": [ - {"fqn": "", "column": "product_name", "filter_value": "Road Bike"}, { - "fqn": "", - "column": "product_name", - "filter_value": "Mountain Bike", + "value": "product_name - Road Bike", + "display": "Road Bike", + }, + { + "value": "product_name - Mountain Bike", + "display": "Mountain Bike", }, ], }, { "question": "Do you mean total sales by volume or number of customers?", "choices": [ - {"fqn": "", "column": "sales_volume", "filter_value": None}, - {"fqn": "", "column": "customer_count", "filter_value": None}, + { + "value": "total_sales - volume", + "display": "Total Sales by Volume", + }, + { + "value": "total_sales - customers", + "display": "Number of Customers", + }, ], }, ] ) -print(request.model_dump()) +print(request.model_dump_json()) class AnswerWithSourcesPayload(PayloadBase): From f46d9bdae876dd7e8ca9a0f8f5e63cf1575aa2a2 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 13 Jan 2025 09:52:31 +0000 Subject: [PATCH 05/10] Update payload --- .../autogen_text_2_sql/autogen_text_2_sql.py | 62 +++++---- .../parallel_query_solving_agent.py | 18 +-- .../custom_agents/sql_query_cache_agent.py | 12 +- .../sql_schema_selection_agent.py | 16 +-- .../custom_agents/sql_query_cache_agent.py | 4 +- .../sql_schema_selection_agent.py | 8 +- .../payloads/interaction_payloads.py | 125 ++++++++---------- 7 files changed, 114 insertions(+), 131 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 9b3abb4..8885af2 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -40,8 +40,8 @@ def get_all_agents(self): # Get current datetime for the Query Rewrite Agent current_datetime = datetime.now() - self.user_input_rewrite_agent = LLMAgentCreator.create( - "user_input_rewrite_agent", current_datetime=current_datetime + self.message_rewrite_agent = LLMAgentCreator.create( + "message_rewrite_agent", current_datetime=current_datetime ) self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs) @@ -49,7 +49,7 @@ def get_all_agents(self): self.answer_agent = LLMAgentCreator.create("answer_agent") agents = [ - self.user_input_rewrite_agent, + self.message_rewrite_agent, self.parallel_query_solving_agent, self.answer_agent, ] @@ -73,11 +73,11 @@ def unified_selector(self, messages): current_agent = messages[-1].source if messages else "user" decision = None - # If this is the first message start with user_input_rewrite_agent + # If this is the first message start with message_rewrite_agent if current_agent == "user": - decision = "user_input_rewrite_agent" + decision = "message_rewrite_agent" # Handle transition after query rewriting - elif current_agent == "user_input_rewrite_agent": + elif current_agent == "message_rewrite_agent": decision = "parallel_query_solving_agent" # Handle transition after parallel query solving elif current_agent == "parallel_query_solving_agent": @@ -145,24 +145,24 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: except json.JSONDecodeError: logging.warning("Unable to read SQL query results: %s", sql_query_results) sql_query_results = {} - sub_user_input_results = {} + sub_message_results = {} else: - # Only load sub-user_input results if we have a database result - sub_user_input_results = self.parse_message_content(messages[1].content) - logging.info("Sub-user_input Results: %s", sub_user_input_results) + # Only load sub-message results if we have a database result + sub_message_results = self.parse_message_content(messages[1].content) + logging.info("Sub-message Results: %s", sub_message_results) try: - sub_user_inputs = [ - sub_user_input - for sub_user_input_group in sub_user_input_results.get( - "sub_user_inputs", [] + decomposed_messages = [ + sub_message + for sub_message_group in sub_message_results.get( + "decomposed_messages", [] ) - for sub_user_input in sub_user_input_group + for sub_message in sub_message_group ] logging.info("SQL Query Results: %s", sql_query_results) payload = AnswerWithSourcesPayload( - answer=answer, sub_user_inputs=sub_user_inputs + answer=answer, decomposed_messages=decomposed_messages ) if not isinstance(sql_query_results, dict): @@ -173,11 +173,9 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: logging.error("No 'results' key in sql_query_results") return payload - for user_input, sql_query_result_list in sql_query_results[ - "results" - ].items(): + for message, sql_query_result_list in sql_query_results["results"].items(): if not sql_query_result_list: # Check if list is empty - logging.warning(f"No results for user_input: {user_input}") + logging.warning(f"No results for message: {message}") continue for sql_query_result in sql_query_result_list: @@ -213,16 +211,16 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: answer=f"{answer}\nError processing results: {str(e)}" ) - async def process_user_input( + async def process_message( self, - user_input_payload: UserInputPayload, + message_payload: UserInputPayload, chat_history: list[InteractionPayload] = None, ) -> AsyncGenerator[InteractionPayload, None]: - """Process the complete user_input through the unified system. + """Process the complete message through the unified system. Args: ---- - task (str): The user user_input to process. + task (str): The user message to process. chat_history (list[str], optional): The chat history. Defaults to None. injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None. @@ -230,22 +228,22 @@ async def process_user_input( ------- dict: The response from the system. """ - logging.info("Processing user_input: %s", user_input_payload.body.user_input) + logging.info("Processing message: %s", message_payload.body.message) logging.info("Chat history: %s", chat_history) agent_input = { - "user_input": user_input_payload.body.user_input, + "message": message_payload.body.message, "chat_history": {}, - "injected_parameters": user_input_payload.body.injected_parameters, + "injected_parameters": message_payload.body.injected_parameters, } if chat_history is not None: # Update input for idx, chat in enumerate(chat_history): - if chat.root.payload_type == PayloadType.USER_INPUT: + if chat.root.payload_type == PayloadType.message: # For now only consider the user query chat_history_key = f"chat_{idx}" - agent_input[chat_history_key] = chat.root.body.user_input + agent_input[chat_history_key] = chat.root.body.message async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): logging.debug("Message: %s", message) @@ -253,7 +251,7 @@ async def process_user_input( payload = None if isinstance(message, TextMessage): - if message.source == "user_input_rewrite_agent": + if message.source == "message_rewrite_agent": payload = ProcessingUpdatePayload( message="Rewriting the query...", ) @@ -276,10 +274,10 @@ async def process_user_input( elif message.messages[-1].source == "parallel_query_solving_agent": # Load into disambiguation request payload = self.extract_disambiguation_request(message.messages) - elif message.messages[-1].source == "user_input_rewrite_agent": + elif message.messages[-1].source == "message_rewrite_agent": # Load into empty response payload = AnswerWithSourcesPayload( - answer="Apologies, I cannot answer that user_input as it is not relevant. Please try another user_input or rephrase your current user_input." + answer="Apologies, I cannot answer that message as it is not relevant. Please try another message or rephrase your current message." ) else: logging.error("Unexpected TaskResult: %s", message) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index fae9a73..7444068 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -84,9 +84,9 @@ async def on_messages_stream( injected_parameters = {} # Load the json of the last message to populate the final output object - user_input_rewrites = json.loads(last_response) + message_rewrites = json.loads(last_response) - logging.info(f"Query Rewrites: {user_input_rewrites}") + logging.info(f"Query Rewrites: {message_rewrites}") async def consume_inner_messages_from_agentic_flow( agentic_flow, identifier, database_results @@ -143,7 +143,7 @@ async def consume_inner_messages_from_agentic_flow( ): logging.info("Contains pre-run results") for pre_run_sql_query, pre_run_result in parsed_message[ - "cached_user_inputs_and_schemas" + "cached_messages_and_schemas" ].items(): database_results[identifier].append( { @@ -164,7 +164,7 @@ async def consume_inner_messages_from_agentic_flow( # Convert all_non_database_query to lowercase string and compare all_non_database_query = str( - user_input_rewrites.get("all_non_database_query", "false") + message_rewrites.get("all_non_database_query", "false") ).lower() if all_non_database_query == "true": @@ -177,12 +177,12 @@ async def consume_inner_messages_from_agentic_flow( return # Start processing sub-queries - for user_input_rewrite in user_input_rewrites["sub_user_inputs"]: - logging.info(f"Processing sub-query: {user_input_rewrite}") + for message_rewrite in message_rewrites["decomposed_messages"]: + logging.info(f"Processing sub-query: {message_rewrite}") # Create an instance of the InnerAutoGenText2Sql class inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs) - identifier = ", ".join(user_input_rewrite) + identifier = ", ".join(message_rewrite) # Add database connection info to injected parameters query_params = injected_parameters.copy() if injected_parameters else {} @@ -196,8 +196,8 @@ async def consume_inner_messages_from_agentic_flow( # Launch tasks for each sub-query inner_solving_generators.append( consume_inner_messages_from_agentic_flow( - inner_autogen_text_2_sql.process_user_input( - user_input=user_input_rewrite, + inner_autogen_text_2_sql.process_message( + message=message_rewrite, injected_parameters=query_params, ), identifier, diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index 9ee16bf..f3f138e 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -17,7 +17,7 @@ class SqlQueryCacheAgent(BaseChatAgent): def __init__(self): super().__init__( "sql_query_cache_agent", - "An agent that fetches the queries from the cache based on the user user_input.", + "An agent that fetches the queries from the cache based on the user message.", ) self.agent = SqlQueryCacheAgentCustomAgent() @@ -40,19 +40,19 @@ async def on_messages( async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: - # Get the decomposed user_inputs from the user_input_rewrite_agent + # Get the decomposed messages from the message_rewrite_agent try: request_details = json.loads(messages[0].content) injected_parameters = request_details["injected_parameters"] - user_user_inputs = request_details["user_input"] - logging.info(f"Processing user_inputs: {user_user_inputs}") + user_messages = request_details["message"] + logging.info(f"Processing messages: {user_messages}") logging.info(f"Input Parameters: {injected_parameters}") except json.JSONDecodeError: - # If not JSON array, process as single user_input + # If not JSON array, process as single message raise ValueError("Could not load message") cached_results = await self.agent.process_message( - user_user_inputs, injected_parameters + user_messages, injected_parameters ) yield Response( chat_message=TextMessage( diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py index 1c5b288..557d7a5 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py @@ -43,19 +43,19 @@ async def on_messages_stream( # Try to parse as JSON first try: request_details = json.loads(messages[0].content) - user_inputs = request_details["question"] + messages = request_details["question"] except (json.JSONDecodeError, KeyError): # If not JSON or missing question key, use content directly - user_inputs = messages[0].content + messages = messages[0].content - if isinstance(user_inputs, str): - user_inputs = [user_inputs] - elif not isinstance(user_inputs, list): - user_inputs = [str(user_inputs)] + if isinstance(messages, str): + messages = [messages] + elif not isinstance(messages, list): + messages = [str(messages)] - logging.info(f"Processing questions: {user_inputs}") + logging.info(f"Processing questions: {messages}") - final_results = await self.agent.process_message(user_inputs) + final_results = await self.agent.process_message(messages) yield Response( chat_message=TextMessage( diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py index c857cdc..d5e490f 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py @@ -9,7 +9,7 @@ def __init__(self): self.sql_connector = ConnectorFactory.get_database_connector() async def process_message( - self, user_inputs: list[str], injected_parameters: dict + self, messages: list[str], injected_parameters: dict ) -> dict: # Initialize results dictionary cached_results = { @@ -18,7 +18,7 @@ async def process_message( } # Process each question sequentially - for question in user_inputs: + for question in messages: # Fetch the queries from the cache based on the question logging.info(f"Fetching queries from cache for question: {question}") cached_query = await self.sql_connector.fetch_queries_from_cache( diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py index 7ee28ff..adfba3b 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py @@ -22,15 +22,15 @@ def __init__(self, **kwargs): self.system_prompt = Template(system_prompt).render(kwargs) - async def process_message(self, user_inputs: list[str]) -> dict: - logging.info(f"user inputs: {user_inputs}") + async def process_message(self, messages: list[str]) -> dict: + logging.info(f"user inputs: {messages}") entity_tasks = [] - for user_input in user_inputs: + for message in messages: messages = [ {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": user_input}, + {"role": "user", "content": message}, ] entity_tasks.append( self.open_ai_connector.run_completion_request( diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index 09f11af..1ca8c99 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -8,18 +8,6 @@ from uuid import uuid4 -class PayloadBase(BaseModel): - prompt_tokens: int | None = None - completion_tokens: int | None = None - message_id: str = Field(..., default_factory=lambda: str(uuid4())) - timestamp: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - description="Timestamp in UTC", - ) - payload_type: str - payload_source: str - - class PayloadSource(StrEnum): USER = "user" AGENT = "agent" @@ -27,28 +15,45 @@ class PayloadSource(StrEnum): class PayloadType(StrEnum): ANSWER_WITH_SOURCES = "answer_with_sources" - DISAMBIGUATION_REQUEST = "disambiguation_request" + DISAMBIGUATION_REQUESTS = "disambiguation_requests" PROCESSING_UPDATE = "processing_update" USER_INPUT = "user_input" -class Choice(BaseModel): - value: str - display: str +class PayloadBase(BaseModel): + prompt_tokens: int | None = Field( + None, description="Number of tokens in the prompt", alias="promptTokens" + ) + completion_tokens: int | None = Field( + None, description="Number of tokens in the completion", alias="completionTokens" + ) + message_id: str = Field( + ..., default_factory=lambda: str(uuid4()), alias="messageId" + ) + timestamp: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + description="Timestamp in UTC", + ) + payload_type: PayloadType = Field(..., alias="payloadType") + payload_source: PayloadSource = Field(..., alias="payloadSource") class DismabiguationRequestsPayload(PayloadBase): class Body(BaseModel): class DismabiguationRequest(BaseModel): - question: str - choices: list[Choice] | None = Field(default=None) + agent_question: str | None = Field(..., alias="agentQuestion") + user_choices: list[str] | None = Field(default=None, alias="userChoices") - disambiguation_requests: list[DismabiguationRequest] + disambiguation_requests: list[DismabiguationRequest] = Field( + alias="disambiguationRequests" + ) - payload_type: Literal[ - PayloadType.DISAMBIGUATION_REQUEST - ] = PayloadType.DISAMBIGUATION_REQUEST - payload_source: Literal[PayloadSource.AGENT] = PayloadSource.AGENT + payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field( + PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType" + ) + payload_source: Literal[PayloadSource.AGENT] = Field( + default=PayloadSource.AGENT, alias="payloadSource" + ) body: Body | None = Field(default=None) def __init__(self, **kwargs): @@ -57,54 +62,24 @@ def __init__(self, **kwargs): self.body = self.Body(**kwargs) -request = DismabiguationRequestsPayload( - disambiguation_requests=[ - { - "question": "Which of the following do you mean?", - "choices": [ - { - "value": "product_name - Road Bike", - "display": "Road Bike", - }, - { - "value": "product_name - Mountain Bike", - "display": "Mountain Bike", - }, - ], - }, - { - "question": "Do you mean total sales by volume or number of customers?", - "choices": [ - { - "value": "total_sales - volume", - "display": "Total Sales by Volume", - }, - { - "value": "total_sales - customers", - "display": "Number of Customers", - }, - ], - }, - ] -) - -print(request.model_dump_json()) - - class AnswerWithSourcesPayload(PayloadBase): class Body(BaseModel): class Source(BaseModel): - sql_query: str - sql_rows: list[dict] + sql_query: str = Field(alias="sqlQuery") + sql_rows: list[dict] = Field(default_factory=list, alias="sqlRows") answer: str - sub_questions: list[str] = Field(default_factory=list) + decomposed_user_messages: list[str] = Field( + default_factory=list, alias="decomposedUserMessages" + ) sources: list[Source] = Field(default_factory=list) - payload_type: Literal[ - PayloadType.ANSWER_WITH_SOURCES - ] = PayloadType.ANSWER_WITH_SOURCES - payload_source: Literal[PayloadSource.AGENT] = PayloadSource.AGENT + payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field( + PayloadType.ANSWER_WITH_SOURCES, alias="payloadType" + ) + payload_source: Literal[PayloadSource.AGENT] = Field( + PayloadSource.AGENT, alias="payloadSource" + ) body: Body | None = Field(default=None) def __init__(self, **kwargs): @@ -118,8 +93,12 @@ class Body(BaseModel): title: str | None = "Processing..." message: str | None = "Processing..." - payload_type: Literal[PayloadType.PROCESSING_UPDATE] = PayloadType.PROCESSING_UPDATE - payload_source: Literal[PayloadSource.AGENT] = PayloadSource.AGENT + payload_type: Literal[PayloadType.PROCESSING_UPDATE] = Field( + PayloadType.PROCESSING_UPDATE, alias="payloadType" + ) + payload_source: Literal[PayloadSource.AGENT] = Field( + PayloadSource.AGENT, alias="payloadSource" + ) body: Body | None = Field(default=None) def __init__(self, **kwargs): @@ -130,8 +109,10 @@ def __init__(self, **kwargs): class UserInputPayload(PayloadBase): class Body(BaseModel): - question: str - injected_parameters: dict = Field(default_factory=dict) + user_message: str = Field(..., alias="userMessage") + injected_parameters: dict = Field( + default_factory=dict, alias="injectedParameters" + ) @model_validator(mode="before") def add_defaults(cls, values): @@ -145,8 +126,12 @@ def add_defaults(cls, values): values["injected_parameters"] = {**defaults, **injected} return values - payload_type: Literal[PayloadType.USER_INPUT] = PayloadType.USER_INPUT - payload_source: Literal[PayloadSource.USER] = PayloadSource.USER + payload_type: Literal[PayloadType.USER_INPUT] = Field( + PayloadType.USER_INPUT, alias="payloadType" + ) + payload_source: Literal[PayloadSource.USER] = Field( + PayloadSource.USER, alias="payloadSource" + ) body: Body | None = Field(default=None) def __init__(self, **kwargs): From 8b882d86d62252059d328083d983932429a3ef14 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 13 Jan 2025 09:58:06 +0000 Subject: [PATCH 06/10] Add model config --- .../payloads/interaction_payloads.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index 1ca8c99..fc9a108 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from pydantic import BaseModel, RootModel, Field, model_validator +from pydantic import BaseModel, RootModel, Field, model_validator, ConfigDict from enum import StrEnum from typing import Literal @@ -20,7 +20,11 @@ class PayloadType(StrEnum): USER_INPUT = "user_input" -class PayloadBase(BaseModel): +class InteractionPayloadBase(BaseModel): + model_config = ConfigDict(allow_population_by_field_name=True, extra="ignore") + + +class PayloadBase(InteractionPayloadBase): prompt_tokens: int | None = Field( None, description="Number of tokens in the prompt", alias="promptTokens" ) @@ -38,9 +42,9 @@ class PayloadBase(BaseModel): payload_source: PayloadSource = Field(..., alias="payloadSource") -class DismabiguationRequestsPayload(PayloadBase): - class Body(BaseModel): - class DismabiguationRequest(BaseModel): +class DismabiguationRequestsPayload(InteractionPayloadBase): + class Body(InteractionPayloadBase): + class DismabiguationRequest(InteractionPayloadBase): agent_question: str | None = Field(..., alias="agentQuestion") user_choices: list[str] | None = Field(default=None, alias="userChoices") @@ -62,9 +66,9 @@ def __init__(self, **kwargs): self.body = self.Body(**kwargs) -class AnswerWithSourcesPayload(PayloadBase): - class Body(BaseModel): - class Source(BaseModel): +class AnswerWithSourcesPayload(InteractionPayloadBase): + class Body(InteractionPayloadBase): + class Source(InteractionPayloadBase): sql_query: str = Field(alias="sqlQuery") sql_rows: list[dict] = Field(default_factory=list, alias="sqlRows") @@ -88,8 +92,8 @@ def __init__(self, **kwargs): self.body = self.Body(**kwargs) -class ProcessingUpdatePayload(PayloadBase): - class Body(BaseModel): +class ProcessingUpdatePayload(InteractionPayloadBase): + class Body(InteractionPayloadBase): title: str | None = "Processing..." message: str | None = "Processing..." @@ -107,8 +111,8 @@ def __init__(self, **kwargs): self.body = self.Body(**kwargs) -class UserInputPayload(PayloadBase): - class Body(BaseModel): +class UserInputPayload(InteractionPayloadBase): + class Body(InteractionPayloadBase): user_message: str = Field(..., alias="userMessage") injected_parameters: dict = Field( default_factory=dict, alias="injectedParameters" From a3e444cd9742e8625df266786fffee8fde83c85e Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 13 Jan 2025 10:41:33 +0000 Subject: [PATCH 07/10] Update entry points --- ...on 5 - Agentic Vector Based Text2SQL.ipynb | 4 +- .../autogen/evaluate_autogen_text2sql.ipynb | 4 +- .../src/autogen_text_2_sql/__init__.py | 4 +- .../autogen_text_2_sql/autogen_text_2_sql.py | 56 +++++++++---------- .../autogen_text_2_sql/evaluation_utils.py | 2 + .../inner_autogen_text_2_sql.py | 9 ++- .../payloads/interaction_payloads.py | 15 +++-- 7 files changed, 49 insertions(+), 45 deletions(-) diff --git a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb index 05ce703..edcf12b 100644 --- a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb +++ b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb @@ -50,7 +50,7 @@ "source": [ "import dotenv\n", "import logging\n", - "from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload" + "from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload" ] }, { @@ -100,7 +100,7 @@ "metadata": {}, "outputs": [], "source": [ - "async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What is the total number of sales?\")):\n", + "async for message in agentic_text_2_sql.process_user_message(UserMessagePayload(user_message=\"What is the total number of sales?\")):\n", " logging.info(\"Received %s Message from Text2SQL System\", message)" ] }, diff --git a/text_2_sql/autogen/evaluate_autogen_text2sql.ipynb b/text_2_sql/autogen/evaluate_autogen_text2sql.ipynb index 8b90716..2b38a09 100644 --- a/text_2_sql/autogen/evaluate_autogen_text2sql.ipynb +++ b/text_2_sql/autogen/evaluate_autogen_text2sql.ipynb @@ -68,7 +68,7 @@ "# Add the src directory to the path\n", "sys.path.append(str(notebook_dir / \"src\"))\n", "\n", - "from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload\n", + "from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload\n", "from autogen_text_2_sql.evaluation_utils import get_final_sql_query\n", "\n", "# Configure logging\n", @@ -127,7 +127,7 @@ " all_queries = []\n", " final_query = None\n", " \n", - " async for message in autogen_text2sql.process_question(QuestionPayload(question=question)):\n", + " async for message in autogen_text2sql.process_user_message(UserMessagePayload(user_message=question)):\n", " if message.payload_type == \"answer_with_sources\":\n", " # Extract from results\n", " if hasattr(message.body, 'results'):\n", diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py index 942e106..bf72f7d 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql -from text_2_sql_core.payloads.interaction_payloads import UserInputPayload +from text_2_sql_core.payloads.interaction_payloads import UserMessagePayload -__all__ = ["AutoGenText2Sql", "UserInputPayload"] +__all__ = ["AutoGenText2Sql", "UserMessagePayload"] diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 8885af2..f4c7b78 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -19,7 +19,7 @@ import re from text_2_sql_core.payloads.interaction_payloads import ( - UserInputPayload, + UserMessagePayload, AnswerWithSourcesPayload, DismabiguationRequestsPayload, ProcessingUpdatePayload, @@ -102,15 +102,6 @@ def agentic_flow(self): ) return flow - def extract_disambiguation_request( - self, messages: list - ) -> DismabiguationRequestsPayload: - """Extract the disambiguation request from the answer.""" - disambiguation_request = messages[-1].content - return DismabiguationRequestsPayload( - disambiguation_request=disambiguation_request, - ) - def parse_message_content(self, content): """Parse different message content formats into a dictionary.""" if isinstance(content, (list, dict)): @@ -134,6 +125,26 @@ def parse_message_content(self, content): # If all parsing attempts fail, return the content as-is return content + def extract_decomposed_user_messages(self, messages: list) -> list[list[str]]: + """Extract the decomposed messages from the answer.""" + # Only load sub-message results if we have a database result + sub_message_results = self.parse_message_content(messages[1].content) + logging.info("Decomposed Results: %s", sub_message_results) + + return sub_message_results.get("decomposed_messages", []) + + def extract_disambiguation_request( + self, messages: list + ) -> DismabiguationRequestsPayload: + """Extract the disambiguation request from the answer.""" + disambiguation_request = messages[-1].content + + decomposed_user_messages = self.extract_decomposed_user_messages(messages) + return DismabiguationRequestsPayload( + disambiguation_request=disambiguation_request, + decomposed_user_messages=decomposed_user_messages, + ) + def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: """Extract the sources from the answer.""" answer = messages[-1].content @@ -145,24 +156,13 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: except json.JSONDecodeError: logging.warning("Unable to read SQL query results: %s", sql_query_results) sql_query_results = {} - sub_message_results = {} - else: - # Only load sub-message results if we have a database result - sub_message_results = self.parse_message_content(messages[1].content) - logging.info("Sub-message Results: %s", sub_message_results) try: - decomposed_messages = [ - sub_message - for sub_message_group in sub_message_results.get( - "decomposed_messages", [] - ) - for sub_message in sub_message_group - ] + decomposed_user_messages = self.extract_decomposed_user_messages(messages) logging.info("SQL Query Results: %s", sql_query_results) payload = AnswerWithSourcesPayload( - answer=answer, decomposed_messages=decomposed_messages + answer=answer, decomposed_user_messages=decomposed_user_messages ) if not isinstance(sql_query_results, dict): @@ -213,7 +213,7 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: async def process_message( self, - message_payload: UserInputPayload, + message_payload: UserMessagePayload, chat_history: list[InteractionPayload] = None, ) -> AsyncGenerator[InteractionPayload, None]: """Process the complete message through the unified system. @@ -228,11 +228,11 @@ async def process_message( ------- dict: The response from the system. """ - logging.info("Processing message: %s", message_payload.body.message) + logging.info("Processing message: %s", message_payload.body.user_message) logging.info("Chat history: %s", chat_history) agent_input = { - "message": message_payload.body.message, + "message": message_payload.body.user_message, "chat_history": {}, "injected_parameters": message_payload.body.injected_parameters, } @@ -240,10 +240,10 @@ async def process_message( if chat_history is not None: # Update input for idx, chat in enumerate(chat_history): - if chat.root.payload_type == PayloadType.message: + if chat.root.payload_type == PayloadType.USER_MESSAGE: # For now only consider the user query chat_history_key = f"chat_{idx}" - agent_input[chat_history_key] = chat.root.body.message + agent_input[chat_history_key] = chat.root.body.user_message async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): logging.debug("Message: %s", message) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py b/text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py index 4f2f158..938c07f 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import re from typing import Optional, List, Dict, Any diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index d73e482..a5ff6ac 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -169,9 +169,9 @@ def agentic_flow(self): ) return flow - def process_question( + def process_user_message( self, - question: str, + user_message: str, injected_parameters: dict = None, ): """Process the complete question through the unified system. @@ -185,15 +185,14 @@ def process_question( ------- dict: The response from the system. """ - logging.info("Processing question: %s", question) + logging.info("Processing question: %s", user_message) # Update environment with injected parameters self._update_environment(injected_parameters) try: agent_input = { - "question": question, - "chat_history": {}, + "user_message": user_message, "injected_parameters": injected_parameters, } diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index fc9a108..2fc0205 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -17,7 +17,7 @@ class PayloadType(StrEnum): ANSWER_WITH_SOURCES = "answer_with_sources" DISAMBIGUATION_REQUESTS = "disambiguation_requests" PROCESSING_UPDATE = "processing_update" - USER_INPUT = "user_input" + USER_MESSAGE = "user_message" class InteractionPayloadBase(BaseModel): @@ -51,6 +51,9 @@ class DismabiguationRequest(InteractionPayloadBase): disambiguation_requests: list[DismabiguationRequest] = Field( alias="disambiguationRequests" ) + decomposed_user_messages: list[list[str]] = Field( + default_factory=list, alias="decomposedUserMessages" + ) payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field( PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType" @@ -73,7 +76,7 @@ class Source(InteractionPayloadBase): sql_rows: list[dict] = Field(default_factory=list, alias="sqlRows") answer: str - decomposed_user_messages: list[str] = Field( + decomposed_user_messages: list[list[str]] = Field( default_factory=list, alias="decomposedUserMessages" ) sources: list[Source] = Field(default_factory=list) @@ -111,7 +114,7 @@ def __init__(self, **kwargs): self.body = self.Body(**kwargs) -class UserInputPayload(InteractionPayloadBase): +class UserMessagePayload(InteractionPayloadBase): class Body(InteractionPayloadBase): user_message: str = Field(..., alias="userMessage") injected_parameters: dict = Field( @@ -130,8 +133,8 @@ def add_defaults(cls, values): values["injected_parameters"] = {**defaults, **injected} return values - payload_type: Literal[PayloadType.USER_INPUT] = Field( - PayloadType.USER_INPUT, alias="payloadType" + payload_type: Literal[PayloadType.USER_MESSAGE] = Field( + PayloadType.USER_MESSAGE, alias="payloadType" ) payload_source: Literal[PayloadSource.USER] = Field( PayloadSource.USER, alias="payloadSource" @@ -145,6 +148,6 @@ def __init__(self, **kwargs): class InteractionPayload(RootModel): - root: UserInputPayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field( + root: UserMessagePayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field( ..., discriminator="payload_type" ) From fd0162290a9b3b63de723f7aaf4d33ab85e8b64a Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 13 Jan 2025 10:49:54 +0000 Subject: [PATCH 08/10] Update prompts and interactions --- .../autogen_text_2_sql/autogen_text_2_sql.py | 20 ++++----- .../parallel_query_solving_agent.py | 2 +- .../custom_agents/sql_query_cache_agent.py | 2 +- .../payloads/interaction_payloads.py | 2 +- ...t.yaml => user_message_rewrite_agent.yaml} | 41 +++++++++---------- 5 files changed, 33 insertions(+), 34 deletions(-) rename text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/{user_input_rewrite_agent.yaml => user_message_rewrite_agent.yaml} (83%) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index f4c7b78..35a669d 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -40,8 +40,8 @@ def get_all_agents(self): # Get current datetime for the Query Rewrite Agent current_datetime = datetime.now() - self.message_rewrite_agent = LLMAgentCreator.create( - "message_rewrite_agent", current_datetime=current_datetime + self.user_message_rewrite_agent = LLMAgentCreator.create( + "user_message_rewrite_agent", current_datetime=current_datetime ) self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs) @@ -49,7 +49,7 @@ def get_all_agents(self): self.answer_agent = LLMAgentCreator.create("answer_agent") agents = [ - self.message_rewrite_agent, + self.user_message_rewrite_agent, self.parallel_query_solving_agent, self.answer_agent, ] @@ -73,11 +73,11 @@ def unified_selector(self, messages): current_agent = messages[-1].source if messages else "user" decision = None - # If this is the first message start with message_rewrite_agent + # If this is the first message start with user_message_rewrite_agent if current_agent == "user": - decision = "message_rewrite_agent" + decision = "user_message_rewrite_agent" # Handle transition after query rewriting - elif current_agent == "message_rewrite_agent": + elif current_agent == "user_message_rewrite_agent": decision = "parallel_query_solving_agent" # Handle transition after parallel query solving elif current_agent == "parallel_query_solving_agent": @@ -131,7 +131,7 @@ def extract_decomposed_user_messages(self, messages: list) -> list[list[str]]: sub_message_results = self.parse_message_content(messages[1].content) logging.info("Decomposed Results: %s", sub_message_results) - return sub_message_results.get("decomposed_messages", []) + return sub_message_results.get("decomposed_user_messages", []) def extract_disambiguation_request( self, messages: list @@ -211,7 +211,7 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: answer=f"{answer}\nError processing results: {str(e)}" ) - async def process_message( + async def process_user_message( self, message_payload: UserMessagePayload, chat_history: list[InteractionPayload] = None, @@ -251,7 +251,7 @@ async def process_message( payload = None if isinstance(message, TextMessage): - if message.source == "message_rewrite_agent": + if message.source == "user_message_rewrite_agent": payload = ProcessingUpdatePayload( message="Rewriting the query...", ) @@ -274,7 +274,7 @@ async def process_message( elif message.messages[-1].source == "parallel_query_solving_agent": # Load into disambiguation request payload = self.extract_disambiguation_request(message.messages) - elif message.messages[-1].source == "message_rewrite_agent": + elif message.messages[-1].source == "user_message_rewrite_agent": # Load into empty response payload = AnswerWithSourcesPayload( answer="Apologies, I cannot answer that message as it is not relevant. Please try another message or rephrase your current message." diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index 7444068..cf21497 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -177,7 +177,7 @@ async def consume_inner_messages_from_agentic_flow( return # Start processing sub-queries - for message_rewrite in message_rewrites["decomposed_messages"]: + for message_rewrite in message_rewrites["decomposed_user_messages"]: logging.info(f"Processing sub-query: {message_rewrite}") # Create an instance of the InnerAutoGenText2Sql class inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index f3f138e..211c21c 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -40,7 +40,7 @@ async def on_messages( async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: - # Get the decomposed messages from the message_rewrite_agent + # Get the decomposed messages from the user_message_rewrite_agent try: request_details = json.loads(messages[0].content) injected_parameters = request_details["injected_parameters"] diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index 2fc0205..9fb047c 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -21,7 +21,7 @@ class PayloadType(StrEnum): class InteractionPayloadBase(BaseModel): - model_config = ConfigDict(allow_population_by_field_name=True, extra="ignore") + model_config = ConfigDict(populate_by_name=True, extra="ignore") class PayloadBase(InteractionPayloadBase): diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_input_rewrite_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_message_rewrite_agent.yaml similarity index 83% rename from text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_input_rewrite_agent.yaml rename to text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_message_rewrite_agent.yaml index 23397e0..5f0e977 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_input_rewrite_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_message_rewrite_agent.yaml @@ -1,8 +1,8 @@ model: "4o-mini" -description: "An agent that preprocesses user inputs by decomposing complex queries into simpler sub-queries that can be processed independently and then combined." +description: "An agent that preprocesses user inputs by decomposing complex queries into simpler sub-messages that can be processed independently and then combined." system_message: | - You are a helpful AI Assistant specializing in breaking down complex questions into simpler sub-queries that can be processed independently and then combined for the final answer. You should identify when a question can be solved through simpler sub-queries and provide clear instructions for combining their results. + You are a helpful AI Assistant specializing in breaking down complex questions into simpler sub-messages that can be processed independently and then combined for the final answer. You should identify when a question can be solved through simpler sub-messages and provide clear instructions for combining their results. @@ -36,7 +36,7 @@ system_message: | 1. Question Filtering and Classification - Use the provided list of topics to filter out malicious or unrelated queries. - Ensure the question is relevant to the system's use case. - - If the question cannot be filtered, output an empty sub-query list in the JSON format. Followed by TERMINATE. + - If the question cannot be filtered, output an empty sub-message list in the JSON format. Followed by TERMINATE. - For non-database questions like greetings (e.g., "Hello", "What can you do?", "How are you?"), set "all_non_database_query" to true. - For questions about data (e.g., queries about records, counts, values, comparisons, or any questions that would require database access), set "all_non_database_query" to false. @@ -51,28 +51,27 @@ system_message: | - Determine if breaking down would simplify processing 4. Break Down Complex Queries: - - Create independent sub-queries that can be processed separately. - - Each sub-query should be a simple, focused task. - - Group dependent sub-queries together for sequential processing. - - Ensure each sub-query is simple and focused + - Create independent sub-messages that can be processed separately. + - Each sub-message should be a simple, focused task. + - Group dependent sub-messages together for sequential processing. - Include clear combination instructions - - Preserve all necessary context in each sub-query + - Preserve all necessary context in each sub-message 5. Handle Date References: - Resolve relative dates using {{ current_datetime }} - Maintain consistent YYYY-MM-DD format - - Include date context in each sub-query + - Include date context in each sub-message 6. Maintain Query Context: - - Each sub-query should be self-contained + - Each sub-message should be self-contained - Include all necessary filtering conditions - Preserve business context 1. Always consider if a complex query can be broken down - 2. Make sub-queries as simple as possible + 2. Make sub-messages as simple as possible 3. Include clear instructions for combining results - 4. Preserve all necessary context in each sub-query + 4. Preserve all necessary context in each sub-message 5. Resolve any relative dates before decomposition @@ -90,11 +89,11 @@ system_message: | - Return a JSON object with sub-queries and combination instructions: + Return a JSON object with sub-messages and combination instructions: { - "sub_questions": [ - [""], - [""], + "decomposed_messages": [ + [""], + [""], ... ], "combination_logic": "", @@ -109,7 +108,7 @@ system_message: | Input: "Which product categories have shown consistent growth quarter over quarter in 2008, and what were their top selling items?" Output: { - "sub_questions": [ + "decomposed_messages": [ ["Calculate quarterly sales totals by product category for 2008", "For these categories, find their top selling products in 2008"] ], "combination_logic": "First identify growing categories from quarterly analysis, then find their best-selling products", @@ -121,7 +120,7 @@ system_message: | Input: "How many orders did we have in 2008?" Output: { - "sub_questions": [ + "decomposed_messages": [ ["How many orders did we have in 2008?"] ], "combination_logic": "Direct count query, no combination needed", @@ -133,12 +132,12 @@ system_message: | Input: "Compare the sales performance of our top 5 products in Europe versus North America, including their market share in each region" Output: { - "sub_questions": [ + "decomposed_messages": [ ["Get total sales by product in European countries"], ["Get total sales by product in North American countries"], ["Calculate total market size for each region", "Find top 5 products by sales in each region"], ], - "combination_logic": "First identify top products in each region, then calculate and compare their market shares. Questions that depend on the result of each sub-query are combined.", + "combination_logic": "First identify top products in each region, then calculate and compare their market shares. Questions that depend on the result of each sub-message are combined.", "query_type": "complex", "all_non_database_query": "false" } @@ -147,7 +146,7 @@ system_message: | Input: "Hello, what can you help me with?" Output: { - "sub_questions": [ + "decomposed_messages": [ ["What are your capabilities?"] ], "combination_logic": "Simple greeting and capability question", From b200d217f9f54d0eaa56e8e9c3f6d0f1a1ac7f63 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 13 Jan 2025 11:00:56 +0000 Subject: [PATCH 09/10] Update agents --- .../custom_agents/parallel_query_solving_agent.py | 4 ++-- .../custom_agents/sql_query_cache_agent.py | 2 +- .../custom_agents/sql_query_cache_agent.py | 6 +++--- .../text_2_sql_core/payloads/interaction_payloads.py | 6 ------ .../prompts/user_message_rewrite_agent.yaml | 10 +++++----- 5 files changed, 11 insertions(+), 17 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index cf21497..b48189b 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -196,8 +196,8 @@ async def consume_inner_messages_from_agentic_flow( # Launch tasks for each sub-query inner_solving_generators.append( consume_inner_messages_from_agentic_flow( - inner_autogen_text_2_sql.process_message( - message=message_rewrite, + inner_autogen_text_2_sql.process_user_message( + user_message=message_rewrite, injected_parameters=query_params, ), identifier, diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index 211c21c..d17b886 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -44,7 +44,7 @@ async def on_messages_stream( try: request_details = json.loads(messages[0].content) injected_parameters = request_details["injected_parameters"] - user_messages = request_details["message"] + user_messages = request_details["user_message"] logging.info(f"Processing messages: {user_messages}") logging.info(f"Input Parameters: {injected_parameters}") except json.JSONDecodeError: diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py index d5e490f..b58f9b5 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py @@ -18,11 +18,11 @@ async def process_message( } # Process each question sequentially - for question in messages: + for message in messages: # Fetch the queries from the cache based on the question - logging.info(f"Fetching queries from cache for question: {question}") + logging.info(f"Fetching queries from cache for question: {message}") cached_query = await self.sql_connector.fetch_queries_from_cache( - question, injected_parameters=injected_parameters + message, injected_parameters=injected_parameters ) # If any question has pre-run results, set the flag diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index 9fb047c..2dd2b55 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -25,12 +25,6 @@ class InteractionPayloadBase(BaseModel): class PayloadBase(InteractionPayloadBase): - prompt_tokens: int | None = Field( - None, description="Number of tokens in the prompt", alias="promptTokens" - ) - completion_tokens: int | None = Field( - None, description="Number of tokens in the completion", alias="completionTokens" - ) message_id: str = Field( ..., default_factory=lambda: str(uuid4()), alias="messageId" ) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_message_rewrite_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_message_rewrite_agent.yaml index 5f0e977..a779d52 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_message_rewrite_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_message_rewrite_agent.yaml @@ -91,7 +91,7 @@ system_message: | Return a JSON object with sub-messages and combination instructions: { - "decomposed_messages": [ + "decomposed_user_messages": [ [""], [""], ... @@ -108,7 +108,7 @@ system_message: | Input: "Which product categories have shown consistent growth quarter over quarter in 2008, and what were their top selling items?" Output: { - "decomposed_messages": [ + "decomposed_user_messages": [ ["Calculate quarterly sales totals by product category for 2008", "For these categories, find their top selling products in 2008"] ], "combination_logic": "First identify growing categories from quarterly analysis, then find their best-selling products", @@ -120,7 +120,7 @@ system_message: | Input: "How many orders did we have in 2008?" Output: { - "decomposed_messages": [ + "decomposed_user_messages": [ ["How many orders did we have in 2008?"] ], "combination_logic": "Direct count query, no combination needed", @@ -132,7 +132,7 @@ system_message: | Input: "Compare the sales performance of our top 5 products in Europe versus North America, including their market share in each region" Output: { - "decomposed_messages": [ + "decomposed_user_messages": [ ["Get total sales by product in European countries"], ["Get total sales by product in North American countries"], ["Calculate total market size for each region", "Find top 5 products by sales in each region"], @@ -146,7 +146,7 @@ system_message: | Input: "Hello, what can you help me with?" Output: { - "decomposed_messages": [ + "decomposed_user_messages": [ ["What are your capabilities?"] ], "combination_logic": "Simple greeting and capability question", From bd4b4b0c9607b46d9f6c76aece072d9c393ac8ce Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 13 Jan 2025 11:53:17 +0000 Subject: [PATCH 10/10] Fix disambiguation --- .../autogen_text_2_sql/autogen_text_2_sql.py | 41 +++++++-- .../parallel_query_solving_agent.py | 85 +++++++++++++++---- .../autogen_text_2_sql/evaluation_utils.py | 2 +- .../inner_autogen_text_2_sql.py | 6 +- .../src/text_2_sql_core/connectors/sql.py | 6 +- .../payloads/interaction_payloads.py | 4 +- ...uation_and_sql_query_generation_agent.yaml | 16 ++-- 7 files changed, 124 insertions(+), 36 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 35a669d..45ca468 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -62,7 +62,7 @@ def termination_condition(self): termination = ( TextMentionTermination("TERMINATE") | SourceMatchTermination("answer_agent") - | TextMentionTermination("requires_user_information_request") + | TextMentionTermination("contains_disambiguation_requests") | MaxMessageTermination(5) ) return termination @@ -131,20 +131,43 @@ def extract_decomposed_user_messages(self, messages: list) -> list[list[str]]: sub_message_results = self.parse_message_content(messages[1].content) logging.info("Decomposed Results: %s", sub_message_results) - return sub_message_results.get("decomposed_user_messages", []) + decomposed_user_messages = sub_message_results.get( + "decomposed_user_messages", [] + ) + + logging.debug( + "Returning decomposed_user_messages: %s", decomposed_user_messages + ) + + return decomposed_user_messages def extract_disambiguation_request( self, messages: list ) -> DismabiguationRequestsPayload: """Extract the disambiguation request from the answer.""" - disambiguation_request = messages[-1].content + all_disambiguation_requests = self.parse_message_content(messages[-1].content) decomposed_user_messages = self.extract_decomposed_user_messages(messages) - return DismabiguationRequestsPayload( - disambiguation_request=disambiguation_request, - decomposed_user_messages=decomposed_user_messages, + request_payload = DismabiguationRequestsPayload( + decomposed_user_messages=decomposed_user_messages ) + for per_question_disambiguation_request in all_disambiguation_requests[ + "disambiguation_requests" + ].values(): + for disambiguation_request in per_question_disambiguation_request: + logging.info( + "Disambiguation Request Identified: %s", disambiguation_request + ) + + request = DismabiguationRequestsPayload.Body.DismabiguationRequest( + agent_question=disambiguation_request["agent_question"], + user_choices=disambiguation_request["user_choices"], + ) + request_payload.body.disambiguation_requests.append(request) + + return request_payload + def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: """Extract the sources from the answer.""" answer = messages[-1].content @@ -169,11 +192,13 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: logging.error(f"Expected dict, got {type(sql_query_results)}") return payload - if "results" not in sql_query_results: + if "database_results" not in sql_query_results: logging.error("No 'results' key in sql_query_results") return payload - for message, sql_query_result_list in sql_query_results["results"].items(): + for message, sql_query_result_list in sql_query_results[ + "database_results" + ].items(): if not sql_query_result_list: # Check if list is empty logging.warning(f"No results for message: {message}") continue diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index b48189b..c93db3e 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -18,6 +18,18 @@ from json import JSONDecodeError import re import os +from pydantic import BaseModel, Field + + +class FilteredParallelMessagesCollection(BaseModel): + database_results: dict[str, list] = Field(default_factory=dict) + disambiguation_requests: dict[str, list] = Field(default_factory=dict) + + def add_identifier(self, identifier): + if identifier not in self.database_results: + self.database_results[identifier] = [] + if identifier not in self.disambiguation_requests: + self.disambiguation_requests[identifier] = [] class ParallelQuerySolvingAgent(BaseChatAgent): @@ -89,7 +101,7 @@ async def on_messages_stream( logging.info(f"Query Rewrites: {message_rewrites}") async def consume_inner_messages_from_agentic_flow( - agentic_flow, identifier, database_results + agentic_flow, identifier, filtered_parallel_messages ): """ Consume the inner messages and append them to the specified list. @@ -101,8 +113,7 @@ async def consume_inner_messages_from_agentic_flow( """ async for inner_message in agentic_flow: # Add message to results dictionary, tagged by the function name - if identifier not in database_results: - database_results[identifier] = [] + filtered_parallel_messages.add_identifier(identifier) logging.info(f"Checking Inner Message: {inner_message}") @@ -122,7 +133,9 @@ async def consume_inner_messages_from_agentic_flow( == "query_execution_with_limit" ): logging.info("Contains query results") - database_results[identifier].append( + filtered_parallel_messages.database_results[ + identifier + ].append( { "sql_query": parsed_message[ "sql_query" @@ -138,6 +151,7 @@ async def consume_inner_messages_from_agentic_flow( # Search for specific message types and add them to the final output object if isinstance(parsed_message, dict): + # Check if the message contains pre-run results if ("contains_pre_run_results" in parsed_message) and ( parsed_message["contains_pre_run_results"] is True ): @@ -145,7 +159,9 @@ async def consume_inner_messages_from_agentic_flow( for pre_run_sql_query, pre_run_result in parsed_message[ "cached_messages_and_schemas" ].items(): - database_results[identifier].append( + filtered_parallel_messages.database_results[ + identifier + ].append( { "sql_query": pre_run_sql_query.replace( "\n", " " @@ -153,6 +169,17 @@ async def consume_inner_messages_from_agentic_flow( "sql_rows": pre_run_result["sql_rows"], } ) + # Check if disambiguation is required + elif ("disambiguation_requests" in parsed_message) and ( + parsed_message["disambiguation_requests"] + ): + logging.info("Contains disambiguation requests") + for disambiguation_request in parsed_message[ + "disambiguation_requests" + ]: + filtered_parallel_messages.disambiguation_requests[ + identifier + ].append(disambiguation_request) except Exception as e: logging.warning(f"Error processing message: {e}") @@ -160,7 +187,7 @@ async def consume_inner_messages_from_agentic_flow( yield inner_message inner_solving_generators = [] - database_results = {} + filtered_parallel_messages = FilteredParallelMessagesCollection() # Convert all_non_database_query to lowercase string and compare all_non_database_query = str( @@ -201,7 +228,7 @@ async def consume_inner_messages_from_agentic_flow( injected_parameters=query_params, ), identifier, - database_results, + filtered_parallel_messages, ) ) @@ -218,17 +245,43 @@ async def consume_inner_messages_from_agentic_flow( yield inner_message # Log final results for debugging or auditing - logging.info(f"Database Results: {database_results}") + logging.info( + "Database Results: %s", filtered_parallel_messages.database_results + ) + logging.info( + "Disambiguation Requests: %s", + filtered_parallel_messages.disambiguation_requests, + ) - # Final response - yield Response( - chat_message=TextMessage( - content=json.dumps( - {"contains_results": True, "results": database_results} + if ( + max(map(len, filtered_parallel_messages.disambiguation_requests.values())) + > 0 + ): + # Final response + yield Response( + chat_message=TextMessage( + content=json.dumps( + { + "contains_disambiguation_requests": True, + "disambiguation_requests": filtered_parallel_messages.disambiguation_requests, + } + ), + source=self.name, ), - source=self.name, - ), - ) + ) + else: + # Final response + yield Response( + chat_message=TextMessage( + content=json.dumps( + { + "contains_database_results": True, + "database_results": filtered_parallel_messages.database_results, + } + ), + source=self.name, + ), + ) async def on_reset(self, cancellation_token: CancellationToken) -> None: pass diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py b/text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py index 938c07f..edb8980 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py @@ -16,7 +16,7 @@ def extract_sql_queries_from_results(results: Dict[str, Any]) -> List[str]: """ queries = [] - if results.get("contains_results") and results.get("results"): + if results.get("contains_database_results") and results.get("results"): for question_results in results["results"].values(): for result in question_results: if isinstance(result, dict) and "sql_query" in result: diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index a5ff6ac..a83000d 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -124,7 +124,11 @@ def get_all_agents(self): @property def termination_condition(self): """Define the termination condition for the chat.""" - termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(10) + termination = ( + TextMentionTermination("TERMINATE") + | MaxMessageTermination(10) + | TextMentionTermination("disambiguation_request") + ) return termination def unified_selector(self, messages): diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 70ef72a..46feb4d 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -187,7 +187,8 @@ async def query_execution_with_limit( "type": "query_execution_with_limit", "sql_query": sql_query, "sql_rows": result, - } + }, + default=str, ) else: return json.dumps( @@ -195,7 +196,8 @@ async def query_execution_with_limit( "type": "errored_query_execution_with_limit", "sql_query": sql_query, "errors": validation_result, - } + }, + default=str, ) async def query_validation( diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index 2dd2b55..1e7e621 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -42,8 +42,8 @@ class DismabiguationRequest(InteractionPayloadBase): agent_question: str | None = Field(..., alias="agentQuestion") user_choices: list[str] | None = Field(default=None, alias="userChoices") - disambiguation_requests: list[DismabiguationRequest] = Field( - alias="disambiguationRequests" + disambiguation_requests: list[DismabiguationRequest] | None = Field( + default_factory=list, alias="disambiguationRequests" ) decomposed_user_messages: list[list[str]] = Field( default_factory=list, alias="decomposedUserMessages" diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml index e31ed70..df8c0c0 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml @@ -170,12 +170,16 @@ system_message: If disambiguation needed: { - \"disambiguation\": [{ - \"question\": \"\", - \"matching_columns\": [\"\", \"\"], - \"matching_filter_values\": [\"\", \"\"], - \"other_user_choices\": [\"\", \"\"] - }] + \"disambiguation_requests\": [ + { + \"agent_question\": \"\", + \"user_choices\": [\"\", \"\"] + }, + { + \"agent_question\": \"\", + \"user_choices\": [\"\", \"\"] + } + ] } TERMINATE