Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"source": [
"import dotenv\n",
"import logging\n",
"from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload"
"from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload"
]
},
{
Expand Down Expand Up @@ -100,7 +100,7 @@
"metadata": {},
"outputs": [],
"source": [
"async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What is the total number of sales?\")):\n",
"async for message in agentic_text_2_sql.process_user_message(UserMessagePayload(user_message=\"What is the total number of sales?\")):\n",
" logging.info(\"Received %s Message from Text2SQL System\", message)"
]
},
Expand Down
4 changes: 2 additions & 2 deletions text_2_sql/autogen/evaluate_autogen_text2sql.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
"# Add the src directory to the path\n",
"sys.path.append(str(notebook_dir / \"src\"))\n",
"\n",
"from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload\n",
"from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload\n",
"from autogen_text_2_sql.evaluation_utils import get_final_sql_query\n",
"\n",
"# Configure logging\n",
Expand Down Expand Up @@ -127,7 +127,7 @@
" all_queries = []\n",
" final_query = None\n",
" \n",
" async for message in autogen_text2sql.process_question(QuestionPayload(question=question)):\n",
" async for message in autogen_text2sql.process_user_message(UserMessagePayload(user_message=question)):\n",
" if message.payload_type == \"answer_with_sources\":\n",
" # Extract from results\n",
" if hasattr(message.body, 'results'):\n",
Expand Down
4 changes: 2 additions & 2 deletions text_2_sql/autogen/src/autogen_text_2_sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
from text_2_sql_core.payloads.interaction_payloads import QuestionPayload
from text_2_sql_core.payloads.interaction_payloads import UserMessagePayload

__all__ = ["AutoGenText2Sql", "QuestionPayload"]
__all__ = ["AutoGenText2Sql", "UserMessagePayload"]
118 changes: 73 additions & 45 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import re

from text_2_sql_core.payloads.interaction_payloads import (
QuestionPayload,
UserMessagePayload,
AnswerWithSourcesPayload,
DismabiguationRequestPayload,
DismabiguationRequestsPayload,
ProcessingUpdatePayload,
InteractionPayload,
PayloadType,
Expand All @@ -40,16 +40,16 @@ def get_all_agents(self):
# Get current datetime for the Query Rewrite Agent
current_datetime = datetime.now()

self.question_rewrite_agent = LLMAgentCreator.create(
"question_rewrite_agent", current_datetime=current_datetime
self.user_message_rewrite_agent = LLMAgentCreator.create(
"user_message_rewrite_agent", current_datetime=current_datetime
)

self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs)

self.answer_agent = LLMAgentCreator.create("answer_agent")

agents = [
self.question_rewrite_agent,
self.user_message_rewrite_agent,
self.parallel_query_solving_agent,
self.answer_agent,
]
Expand All @@ -62,7 +62,7 @@ def termination_condition(self):
termination = (
TextMentionTermination("TERMINATE")
| SourceMatchTermination("answer_agent")
| TextMentionTermination("requires_user_information_request")
| TextMentionTermination("contains_disambiguation_requests")
| MaxMessageTermination(5)
)
return termination
Expand All @@ -73,11 +73,11 @@ def unified_selector(self, messages):
current_agent = messages[-1].source if messages else "user"
decision = None

# If this is the first message start with question_rewrite_agent
# If this is the first message start with user_message_rewrite_agent
if current_agent == "user":
decision = "question_rewrite_agent"
decision = "user_message_rewrite_agent"
# Handle transition after query rewriting
elif current_agent == "question_rewrite_agent":
elif current_agent == "user_message_rewrite_agent":
decision = "parallel_query_solving_agent"
# Handle transition after parallel query solving
elif current_agent == "parallel_query_solving_agent":
Expand All @@ -102,15 +102,6 @@ def agentic_flow(self):
)
return flow

def extract_disambiguation_request(
self, messages: list
) -> DismabiguationRequestPayload:
"""Extract the disambiguation request from the answer."""
disambiguation_request = messages[-1].content
return DismabiguationRequestPayload(
disambiguation_request=disambiguation_request,
)

def parse_message_content(self, content):
"""Parse different message content formats into a dictionary."""
if isinstance(content, (list, dict)):
Expand All @@ -134,6 +125,49 @@ def parse_message_content(self, content):
# If all parsing attempts fail, return the content as-is
return content

def extract_decomposed_user_messages(self, messages: list) -> list[list[str]]:
"""Extract the decomposed messages from the answer."""
# Only load sub-message results if we have a database result
sub_message_results = self.parse_message_content(messages[1].content)
logging.info("Decomposed Results: %s", sub_message_results)

decomposed_user_messages = sub_message_results.get(
"decomposed_user_messages", []
)

logging.debug(
"Returning decomposed_user_messages: %s", decomposed_user_messages
)

return decomposed_user_messages

def extract_disambiguation_request(
self, messages: list
) -> DismabiguationRequestsPayload:
"""Extract the disambiguation request from the answer."""
all_disambiguation_requests = self.parse_message_content(messages[-1].content)

decomposed_user_messages = self.extract_decomposed_user_messages(messages)
request_payload = DismabiguationRequestsPayload(
decomposed_user_messages=decomposed_user_messages
)

for per_question_disambiguation_request in all_disambiguation_requests[
"disambiguation_requests"
].values():
for disambiguation_request in per_question_disambiguation_request:
logging.info(
"Disambiguation Request Identified: %s", disambiguation_request
)

request = DismabiguationRequestsPayload.Body.DismabiguationRequest(
agent_question=disambiguation_request["agent_question"],
user_choices=disambiguation_request["user_choices"],
)
request_payload.body.disambiguation_requests.append(request)

return request_payload

def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
"""Extract the sources from the answer."""
answer = messages[-1].content
Expand All @@ -145,41 +179,35 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
except json.JSONDecodeError:
logging.warning("Unable to read SQL query results: %s", sql_query_results)
sql_query_results = {}
sub_question_results = {}
else:
# Only load sub-question results if we have a database result
sub_question_results = self.parse_message_content(messages[1].content)
logging.info("Sub-Question Results: %s", sub_question_results)

try:
sub_questions = [
sub_question
for sub_question_group in sub_question_results.get("sub_questions", [])
for sub_question in sub_question_group
]
decomposed_user_messages = self.extract_decomposed_user_messages(messages)

logging.info("SQL Query Results: %s", sql_query_results)
payload = AnswerWithSourcesPayload(
answer=answer, sub_questions=sub_questions
answer=answer, decomposed_user_messages=decomposed_user_messages
)

if not isinstance(sql_query_results, dict):
logging.error(f"Expected dict, got {type(sql_query_results)}")
return payload

if "results" not in sql_query_results:
if "database_results" not in sql_query_results:
logging.error("No 'results' key in sql_query_results")
return payload

for question, sql_query_result_list in sql_query_results["results"].items():
for message, sql_query_result_list in sql_query_results[
"database_results"
].items():
if not sql_query_result_list: # Check if list is empty
logging.warning(f"No results for question: {question}")
logging.warning(f"No results for message: {message}")
continue

for sql_query_result in sql_query_result_list:
if not isinstance(sql_query_result, dict):
logging.error(
f"Expected dict for sql_query_result, got {type(sql_query_result)}"
"Expected dict for sql_query_result, got %s",
type(sql_query_result),
)
continue

Expand Down Expand Up @@ -208,47 +236,47 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
answer=f"{answer}\nError processing results: {str(e)}"
)

async def process_question(
async def process_user_message(
self,
question_payload: QuestionPayload,
message_payload: UserMessagePayload,
chat_history: list[InteractionPayload] = None,
) -> AsyncGenerator[InteractionPayload, None]:
"""Process the complete question through the unified system.
"""Process the complete message through the unified system.

Args:
----
task (str): The user question to process.
task (str): The user message to process.
chat_history (list[str], optional): The chat history. Defaults to None.
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.

Returns:
-------
dict: The response from the system.
"""
logging.info("Processing question: %s", question_payload.body.question)
logging.info("Processing message: %s", message_payload.body.user_message)
logging.info("Chat history: %s", chat_history)

agent_input = {
"question": question_payload.body.question,
"message": message_payload.body.user_message,
"chat_history": {},
"injected_parameters": question_payload.body.injected_parameters,
"injected_parameters": message_payload.body.injected_parameters,
}

if chat_history is not None:
# Update input
for idx, chat in enumerate(chat_history):
if chat.root.payload_type == PayloadType.QUESTION:
if chat.root.payload_type == PayloadType.USER_MESSAGE:
# For now only consider the user query
chat_history_key = f"chat_{idx}"
agent_input[chat_history_key] = chat.root.body.question
agent_input[chat_history_key] = chat.root.body.user_message

async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
logging.debug("Message: %s", message)

payload = None

if isinstance(message, TextMessage):
if message.source == "question_rewrite_agent":
if message.source == "user_message_rewrite_agent":
payload = ProcessingUpdatePayload(
message="Rewriting the query...",
)
Expand All @@ -271,10 +299,10 @@ async def process_question(
elif message.messages[-1].source == "parallel_query_solving_agent":
# Load into disambiguation request
payload = self.extract_disambiguation_request(message.messages)
elif message.messages[-1].source == "question_rewrite_agent":
elif message.messages[-1].source == "user_message_rewrite_agent":
# Load into empty response
payload = AnswerWithSourcesPayload(
answer="Apologies, I cannot answer that question as it is not relevant. Please try another question or rephrase your current question."
answer="Apologies, I cannot answer that message as it is not relevant. Please try another message or rephrase your current message."
)
else:
logging.error("Unexpected TaskResult: %s", message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_tool(cls, sql_helper, tool_name: str):
elif tool_name == "sql_get_entity_schemas_tool":
return FunctionToolAlias(
sql_helper.get_entity_schemas,
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.",
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user input and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.",
)
elif tool_name == "sql_get_column_values_tool":
return FunctionToolAlias(
Expand Down
Loading
Loading