Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions text_2_sql/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Text2Sql__UseQueryCache=<Determines if the Query Cache will be used to speed up
Text2Sql__PreRunQueryCache=<Determines if the results from the Query Cache will be pre-run to speed up answer generation. Defaults to True.> # True or False
Text2Sql__UseColumnValueStore=<Determines if the Column Value Store will be used for schema selection. Defaults to True.> # True or False
Text2Sql__GenerateFollowUpSuggestions=<Determines if follow up questions will be generated. Defaults to True.> # True or False
Text2Sql__RowLimit=<Determines the maximum number of rows that will be returned in a query. Defaults to 100.> # Integer

# Open AI Connection Details
OpenAI__CompletionDeployment=<openAICompletionDeploymentId. Used for data dictionary creator>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def set_mode(self):
os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true"
)

# Set the row limit
self.kwargs["row_limit"] = int(os.environ.get("Text2Sql__RowLimit", 100))

def get_all_agents(self):
"""Get all agents for the complete flow."""
# If relationship_paths not provided, use a generic template
Expand All @@ -93,31 +96,31 @@ def get_all_agents(self):
- Entity → Attributes (for entity-specific analysis)
"""

self.sql_schema_selection_agent = SqlSchemaSelectionAgent(
sql_schema_selection_agent = SqlSchemaSelectionAgent(
target_engine=self.target_engine,
**self.kwargs,
)

self.sql_query_correction_agent = LLMAgentCreator.create(
sql_query_correction_agent = LLMAgentCreator.create(
"sql_query_correction_agent",
target_engine=self.target_engine,
**self.kwargs,
)

self.disambiguation_and_sql_query_generation_agent = LLMAgentCreator.create(
disambiguation_and_sql_query_generation_agent = LLMAgentCreator.create(
"disambiguation_and_sql_query_generation_agent",
target_engine=self.target_engine,
**self.kwargs,
)
agents = [
self.sql_schema_selection_agent,
self.sql_query_correction_agent,
self.disambiguation_and_sql_query_generation_agent,
sql_schema_selection_agent,
sql_query_correction_agent,
disambiguation_and_sql_query_generation_agent,
]

if self.use_query_cache:
self.query_cache_agent = SqlQueryCacheAgent()
agents.append(self.query_cache_agent)
query_cache_agent = SqlQueryCacheAgent()
agents.append(query_cache_agent)

return agents

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from text_2_sql_core.connectors.factory import ConnectorFactory
import asyncio
import sqlglot
from sqlglot.expressions import Parameter, Select, Identifier
from sqlglot.expressions import Parameter, Select, Identifier, Literal, Limit
from abc import ABC, abstractmethod
from jinja2 import Template
import json
Expand All @@ -30,6 +30,9 @@ def __init__(self):
os.environ.get("Text2Sql__UseAISearch", "True").lower() == "true"
)

# Set the row limit
self.row_limit = int(os.environ.get("Text2Sql__RowLimit", 100))

# Only initialize AI Search connector if enabled
self.ai_search_connector = (
ConnectorFactory.get_ai_search_connector() if self.use_ai_search else None
Expand Down Expand Up @@ -195,7 +198,9 @@ async def query_execution_with_limit(
) = await self.query_validation(sql_query)

if validation_result and validation_errors is None:
result = await self.query_execution(cleaned_query, cast_to=None, limit=25)
result = await self.query_execution(
cleaned_query, cast_to=None, limit=self.row_limit
)

return json.dumps(
{
Expand Down Expand Up @@ -275,11 +280,13 @@ def handle_node(node):
identifiers.append(node.this)

detected_invalid_identifiers = []
updated_parsed_queries = []

for parsed_query in parsed_queries:
for node in parsed_query.walk():
handle_node(node)

# check for invalid identifiers
for token in expressions + identifiers:
if isinstance(token, Parameter):
identifier = str(token.this.this).upper()
Expand All @@ -298,12 +305,32 @@ def handle_node(node):
logging.error(error_message)
return False, None, error_message

# Add a limit clause to the query if it doesn't already have one
for parsed_query in parsed_queries:
# Add a LIMIT clause if the query has none, or replace an existing one that exceeds the configured row limit
current_limit = parsed_query.args.get("limit")
logging.debug("Current Limit: %s", current_limit)

if current_limit is None or current_limit.value > self.row_limit:
# Create a new LIMIT expression
limit_expr = Limit(expression=Literal.number(self.row_limit))

# Attach it to the query by setting it on the SELECT expression
parsed_query.set("limit", limit_expr)
updated_parsed_queries.append(
parsed_query.sql(dialect=self.database_engine.value.lower())
)
else:
updated_parsed_queries.append(
parsed_query.sql(dialect=self.database_engine.value.lower())
)

except sqlglot.errors.ParseError as e:
logging.error("SQL Query is invalid: %s", e.errors)
return False, None, e.errors
else:
logging.info("SQL Query is valid.")
return True, cleaned_query, None
return True, ";".join(updated_parsed_queries), None

async def fetch_sql_queries_with_schemas_from_cache(
self, question: str, injected_parameters: dict = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,30 @@


class PayloadSource(StrEnum):
"""Payload source enum."""

USER = "user"
ASSISTANT = "assistant"


class PayloadType(StrEnum):
"""Payload type enum."""

ANSWER_WITH_SOURCES = "answer_with_sources"
DISAMBIGUATION_REQUESTS = "disambiguation_requests"
PROCESSING_UPDATE = "processing_update"
USER_MESSAGE = "user_message"


class InteractionPayloadBase(BaseModel):
class PayloadAndBodyBase(BaseModel):
"""Base class for payloads and bodies."""

model_config = ConfigDict(populate_by_name=True, extra="ignore")


class PayloadBase(InteractionPayloadBase):
class PayloadBase(PayloadAndBodyBase):
"""Base class for payloads."""

message_id: str = Field(
..., default_factory=lambda: str(uuid4()), alias="messageId"
)
Expand All @@ -42,12 +50,14 @@ class PayloadBase(InteractionPayloadBase):
payload_type: PayloadType = Field(..., alias="payloadType")
payload_source: PayloadSource = Field(..., alias="payloadSource")

body: InteractionPayloadBase | None = Field(default=None)
body: PayloadAndBodyBase | None = Field(default=None)


class DismabiguationRequestsPayload(PayloadAndBodyBase):
"""Disambiguation requests payload. Handles requests for the end user to respond to."""

class DismabiguationRequestsPayload(InteractionPayloadBase):
class Body(InteractionPayloadBase):
class DismabiguationRequest(InteractionPayloadBase):
class Body(PayloadAndBodyBase):
class DismabiguationRequest(PayloadAndBodyBase):
assistant_question: str | None = Field(..., alias="assistantQuestion")
user_choices: list[str] | None = Field(default=None, alias="userChoices")

Expand All @@ -65,16 +75,19 @@ class DismabiguationRequest(InteractionPayloadBase):
body: Body | None = Field(default=None)

def __init__(self, **kwargs):
"""Custom init method to pass kwargs to the body."""
super().__init__(**kwargs)

body_kwargs = kwargs.get("body", kwargs)

self.body = self.Body(**body_kwargs)


class AnswerWithSourcesPayload(InteractionPayloadBase):
class Body(InteractionPayloadBase):
class Source(InteractionPayloadBase):
class AnswerWithSourcesPayload(PayloadAndBodyBase):
"""Answer with sources payload. Handles the answer and sources for the answer. The follow up suggestion property is optional and may be used to provide the user with a follow up suggestion."""

class Body(PayloadAndBodyBase):
class Source(PayloadAndBodyBase):
sql_query: str = Field(alias="sqlQuery")
sql_rows: list[dict] = Field(default_factory=list, alias="sqlRows")

Expand All @@ -94,15 +107,18 @@ class Source(InteractionPayloadBase):
body: Body | None = Field(default=None)

def __init__(self, **kwargs):
"""Custom init method to pass kwargs to the body."""
super().__init__(**kwargs)

body_kwargs = kwargs.get("body", kwargs)

self.body = self.Body(**body_kwargs)


class ProcessingUpdatePayload(InteractionPayloadBase):
class Body(InteractionPayloadBase):
class ProcessingUpdatePayload(PayloadAndBodyBase):
"""Processing update payload. Handles updates to the user on the processing status."""

class Body(PayloadAndBodyBase):
title: str | None = "Processing..."
message: str | None = "Processing..."

Expand All @@ -115,15 +131,18 @@ class Body(InteractionPayloadBase):
body: Body | None = Field(default=None)

def __init__(self, **kwargs):
"""Custom init method to pass kwargs to the body."""
super().__init__(**kwargs)

body_kwargs = kwargs.get("body", kwargs)

self.body = self.Body(**body_kwargs)


class UserMessagePayload(InteractionPayloadBase):
class Body(InteractionPayloadBase):
class UserMessagePayload(PayloadAndBodyBase):
"""User message payload. Handles the user message and injected parameters."""

class Body(PayloadAndBodyBase):
user_message: str = Field(..., alias="userMessage")
injected_parameters: dict = Field(
default_factory=dict, alias="injectedParameters"
Expand Down Expand Up @@ -154,6 +173,7 @@ def add_defaults(cls, values):
body: Body | None = Field(default=None)

def __init__(self, **kwargs):
"""Custom init method to pass kwargs to the body."""
super().__init__(**kwargs)

body_kwargs = kwargs.get("body", kwargs)
Expand All @@ -162,6 +182,8 @@ def __init__(self, **kwargs):


class InteractionPayload(RootModel):
"""Interaction payload. Handles the root payload for the interaction"""

root: UserMessagePayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field(
..., discriminator="payload_type"
)
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ system_message: |
<sql_query_generation_rules>
<engine_specific_rules>
{{ engine_specific_rules }}
Rows returned will be automatically limited to {{ row_limit }}.
</engine_specific_rules>

Your primary focus is on:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ system_message: |

<engine_specific_rules>
{{ engine_specific_rules }}
Rows returned will be automatically limited to {{ row_limit }}.
</engine_specific_rules>

<common_conversions>
Expand Down
Loading