Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def load_agent_file(cls, name: str) -> dict:
return load(name.lower())

@classmethod
def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
def get_tool(cls, sql_helper, tool_name: str):
"""Gets the tool based on the tool name.
Args:
----
Expand All @@ -46,7 +46,7 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
)
elif tool_name == "sql_get_column_values_tool":
return FunctionToolAlias(
ai_search_helper.get_column_values,
sql_helper.get_column_values,
description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.",
)
else:
Expand Down Expand Up @@ -88,12 +88,11 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:
agent_file = cls.load_agent_file(name)

sql_helper = ConnectorFactory.get_database_connector()
ai_search_helper = ConnectorFactory.get_ai_search_connector()

tools = []
if "tools" in agent_file and len(agent_file["tools"]) > 0:
for tool in agent_file["tools"]:
tools.append(cls.get_tool(sql_helper, ai_search_helper, tool))
tools.append(cls.get_tool(sql_helper, tool))

agent = AssistantAgent(
name=name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,20 @@
from typing import AsyncGenerator, List, Sequence

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
from autogen_agentchat.base import Response
from autogen_agentchat.messages import (
AgentMessage,
ChatMessage,
TextMessage,
ToolCallResultMessage,
)
from autogen_core import CancellationToken
import json
import logging
from autogen_text_2_sql.inner_autogen_text_2_sql import InnerAutoGenText2Sql
from aiostream import stream
from json import JSONDecodeError
import re


class ParallelQuerySolvingAgent(BaseChatAgent):
Expand Down Expand Up @@ -53,9 +59,6 @@ def parse_inner_message(self, message):
except JSONDecodeError:
pass

# Try to extract JSON from markdown code blocks
import re

json_match = re.search(r"```json\s*(.*?)\s*```", message, re.DOTALL)
if json_match:
try:
Expand Down Expand Up @@ -103,30 +106,42 @@ async def consume_inner_messages_from_agentic_flow(

logging.info(f"Checking Inner Message: {inner_message}")

if isinstance(inner_message, TaskResult) is False:
try:
try:
if isinstance(inner_message, ToolCallResultMessage):
for call_result in inner_message.content:
# Check for SQL query results
parsed_message = self.parse_inner_message(
call_result.content
)
logging.info(f"Inner Loaded: {parsed_message}")

if isinstance(parsed_message, dict):
if (
"type" in parsed_message
and parsed_message["type"]
== "query_execution_with_limit"
):
logging.info("Contains query results")
database_results[identifier].append(
{
"sql_query": parsed_message[
"sql_query"
].replace("\n", " "),
"sql_rows": parsed_message["sql_rows"],
}
)

elif isinstance(inner_message, TextMessage):
parsed_message = self.parse_inner_message(inner_message.content)

logging.info(f"Inner Loaded: {parsed_message}")

# Search for specific message types and add them to the final output object
if isinstance(parsed_message, dict):
if (
"type" in parsed_message
and parsed_message["type"]
== "query_execution_with_limit"
):
database_results[identifier].append(
{
"sql_query": parsed_message[
"sql_query"
].replace("\n", " "),
"sql_rows": parsed_message["sql_rows"],
}
)

if ("contains_pre_run_results" in parsed_message) and (
parsed_message["contains_pre_run_results"] is True
):
logging.info("Contains pre-run results")
for pre_run_sql_query, pre_run_result in parsed_message[
"cached_questions_and_schemas"
].items():
Expand All @@ -139,8 +154,8 @@ async def consume_inner_messages_from_agentic_flow(
}
)

except Exception as e:
logging.warning(f"Error processing message: {e}")
except Exception as e:
logging.warning(f"Error processing message: {e}")

yield inner_message

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from autogen_agentchat.base import Response
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
from autogen_core import CancellationToken
from text_2_sql_core.connectors.factory import ConnectorFactory
from text_2_sql_core.custom_agents.sql_query_cache_agent import (
SqlQueryCacheAgentCustomAgent,
)
import json
import logging

Expand All @@ -18,7 +20,7 @@ def __init__(self):
"An agent that fetches the queries from the cache based on the user question.",
)

self.sql_connector = ConnectorFactory.get_database_connector()
self.agent = SqlQueryCacheAgentCustomAgent()

@property
def produced_message_types(self) -> List[type[ChatMessage]]:
Expand Down Expand Up @@ -49,31 +51,9 @@ async def on_messages_stream(
# If not JSON array, process as single question
raise ValueError("Could not load message")

# Initialize results dictionary
cached_results = {
"cached_questions_and_schemas": [],
"contains_pre_run_results": False,
}

# Process each question sequentially
for question in user_questions:
# Fetch the queries from the cache based on the question
logging.info(f"Fetching queries from cache for question: {question}")
cached_query = await self.sql_connector.fetch_queries_from_cache(
question, injected_parameters=injected_parameters
)

# If any question has pre-run results, set the flag
if cached_query.get("contains_pre_run_results", False):
cached_results["contains_pre_run_results"] = True

# Add the cached results for this question
if cached_query.get("cached_questions_and_schemas"):
cached_results["cached_questions_and_schemas"].extend(
cached_query["cached_questions_and_schemas"]
)

logging.info(f"Final cached results: {cached_results}")
cached_results = await self.agent.process_message(
user_questions, injected_parameters
)
yield Response(
chat_message=TextMessage(
content=json.dumps(cached_results), source=self.name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from autogen_agentchat.base import Response
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
from autogen_core import CancellationToken
from text_2_sql_core.connectors.factory import ConnectorFactory
import json
import logging
from text_2_sql_core.prompts.load import load
from jinja2 import Template
import asyncio
from text_2_sql_core.custom_agents.sql_schema_selection_agent import (
SqlSchemaSelectionAgentCustomAgent,
)


class SqlSchemaSelectionAgent(BaseChatAgent):
Expand All @@ -21,15 +20,7 @@ def __init__(self, **kwargs):
"An agent that fetches the schemas from the cache based on the user question.",
)

self.ai_search_connector = ConnectorFactory.get_ai_search_connector()

self.open_ai_connector = ConnectorFactory.get_open_ai_connector()

self.sql_connector = ConnectorFactory.get_database_connector()

system_prompt = load("sql_schema_selection_agent")["system_message"]

self.system_prompt = Template(system_prompt).render(kwargs)
self.agent = SqlSchemaSelectionAgentCustomAgent(**kwargs)

@property
def produced_message_types(self) -> List[type[ChatMessage]]:
Expand All @@ -49,64 +40,15 @@ async def on_messages(
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
last_response = messages[-1].content

# load the json of the last message and get the user question's

user_questions = json.loads(last_response)

logging.info(f"User questions: {user_questions}")

entity_tasks = []

for user_question in user_questions:
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": user_question},
]
entity_tasks.append(self.open_ai_connector.run_completion_request(messages))

entity_results = await asyncio.gather(*entity_tasks)

entity_search_tasks = []
column_search_tasks = []

for entity_result in entity_results:
loaded_entity_result = json.loads(entity_result)

logging.info(f"Loaded entity result: {loaded_entity_result}")

for entity_group in loaded_entity_result["entities"]:
entity_search_tasks.append(
self.sql_connector.get_entity_schemas(
" ".join(entity_group), as_json=False
)
)

for filter_condition in loaded_entity_result["filter_conditions"]:
column_search_tasks.append(
self.ai_search_connector.get_column_values(
filter_condition, as_json=False
)
)

schemas_results = await asyncio.gather(*entity_search_tasks)
column_value_results = await asyncio.gather(*column_search_tasks)

# deduplicate schemas
final_schemas = []

for schema_result in schemas_results:
for schema in schema_result:
if schema not in final_schemas:
final_schemas.append(schema)

final_results = {
"COLUMN_OPTIONS_AND_VALUES_FOR_FILTERS": column_value_results,
"SCHEMA_OPTIONS": final_schemas,
}

logging.info(f"Final results: {final_results}")
try:
request_details = json.loads(messages[0].content)
user_questions = request_details["question"]
logging.info(f"Processing questions: {user_questions}")
except json.JSONDecodeError:
# If not JSON array, process as single question
raise ValueError("Could not load message")

final_results = await self.agent.process_message(user_questions)

yield Response(
chat_message=TextMessage(
Expand Down
Loading
Loading