diff --git a/text_2_sql/.env.example b/text_2_sql/.env.example index e87265c..bca0949 100644 --- a/text_2_sql/.env.example +++ b/text_2_sql/.env.example @@ -6,6 +6,7 @@ Text2Sql__UseQueryCache= # True or False Text2Sql__UseColumnValueStore= # True or False Text2Sql__GenerateFollowUpSuggestions= # True or False +Text2Sql__RowLimit= # Integer # Open AI Connection Details OpenAI__CompletionDeployment= diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index 0705cb8..818b7d3 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -79,6 +79,9 @@ def set_mode(self): os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true" ) + # Set the row limit + self.kwargs["row_limit"] = int(os.environ.get("Text2Sql__RowLimit", 100)) + def get_all_agents(self): """Get all agents for the complete flow.""" # If relationship_paths not provided, use a generic template @@ -93,31 +96,31 @@ def get_all_agents(self): - Entity → Attributes (for entity-specific analysis) """ - self.sql_schema_selection_agent = SqlSchemaSelectionAgent( + sql_schema_selection_agent = SqlSchemaSelectionAgent( target_engine=self.target_engine, **self.kwargs, ) - self.sql_query_correction_agent = LLMAgentCreator.create( + sql_query_correction_agent = LLMAgentCreator.create( "sql_query_correction_agent", target_engine=self.target_engine, **self.kwargs, ) - self.disambiguation_and_sql_query_generation_agent = LLMAgentCreator.create( + disambiguation_and_sql_query_generation_agent = LLMAgentCreator.create( "disambiguation_and_sql_query_generation_agent", target_engine=self.target_engine, **self.kwargs, ) agents = [ - self.sql_schema_selection_agent, - self.sql_query_correction_agent, - self.disambiguation_and_sql_query_generation_agent, + sql_schema_selection_agent, + sql_query_correction_agent, + 
disambiguation_and_sql_query_generation_agent, ] if self.use_query_cache: - self.query_cache_agent = SqlQueryCacheAgent() - agents.append(self.query_cache_agent) + query_cache_agent = SqlQueryCacheAgent() + agents.append(query_cache_agent) return agents diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index dc95ceb..e31588c 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -6,7 +6,7 @@ from text_2_sql_core.connectors.factory import ConnectorFactory import asyncio import sqlglot -from sqlglot.expressions import Parameter, Select, Identifier +from sqlglot.expressions import Parameter, Select, Identifier, Literal, Limit from abc import ABC, abstractmethod from jinja2 import Template import json @@ -30,6 +30,9 @@ def __init__(self): os.environ.get("Text2Sql__UseAISearch", "True").lower() == "true" ) + # Set the row limit + self.row_limit = int(os.environ.get("Text2Sql__RowLimit", 100)) + # Only initialize AI Search connector if enabled self.ai_search_connector = ( ConnectorFactory.get_ai_search_connector() if self.use_ai_search else None @@ -195,7 +198,9 @@ async def query_execution_with_limit( ) = await self.query_validation(sql_query) if validation_result and validation_errors is None: - result = await self.query_execution(cleaned_query, cast_to=None, limit=25) + result = await self.query_execution( + cleaned_query, cast_to=None, limit=self.row_limit + ) return json.dumps( { @@ -275,11 +280,13 @@ def handle_node(node): identifiers.append(node.this) detected_invalid_identifiers = [] + updated_parsed_queries = [] for parsed_query in parsed_queries: for node in parsed_query.walk(): handle_node(node) + # check for invalid identifiers for token in expressions + identifiers: if isinstance(token, Parameter): identifier = str(token.this.this).upper() @@ -298,12 +305,32 @@ def 
handle_node(node): logging.error(error_message) return False, None, error_message + # Add a limit clause to the query if it doesn't already have one + for parsed_query in parsed_queries: + # Add a limit clause to the query if it doesn't already have one + current_limit = parsed_query.args.get("limit") + logging.debug("Current Limit: %s", current_limit) + + if current_limit is None or current_limit.value > self.row_limit: + # Create a new LIMIT expression + limit_expr = Limit(expression=Literal.number(self.row_limit)) + + # Attach it to the query by setting it on the SELECT expression + parsed_query.set("limit", limit_expr) + updated_parsed_queries.append( + parsed_query.sql(dialect=self.database_engine.value.lower()) + ) + else: + updated_parsed_queries.append( + parsed_query.sql(dialect=self.database_engine.value.lower()) + ) + except sqlglot.errors.ParseError as e: logging.error("SQL Query is invalid: %s", e.errors) return False, None, e.errors else: logging.info("SQL Query is valid.") - return True, cleaned_query, None + return True, ";".join(updated_parsed_queries), None async def fetch_sql_queries_with_schemas_from_cache( self, question: str, injected_parameters: dict = None diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index 62f49c5..bbfec49 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -16,22 +16,30 @@ class PayloadSource(StrEnum): + """Payload source enum.""" + USER = "user" ASSISTANT = "assistant" class PayloadType(StrEnum): + """Payload type enum.""" + ANSWER_WITH_SOURCES = "answer_with_sources" DISAMBIGUATION_REQUESTS = "disambiguation_requests" PROCESSING_UPDATE = "processing_update" USER_MESSAGE = "user_message" -class InteractionPayloadBase(BaseModel): +class PayloadAndBodyBase(BaseModel): + 
"""Base class for payloads and bodies.""" + model_config = ConfigDict(populate_by_name=True, extra="ignore") -class PayloadBase(InteractionPayloadBase): +class PayloadBase(PayloadAndBodyBase): + """Base class for payloads.""" + message_id: str = Field( ..., default_factory=lambda: str(uuid4()), alias="messageId" ) @@ -42,12 +50,14 @@ class PayloadBase(InteractionPayloadBase): payload_type: PayloadType = Field(..., alias="payloadType") payload_source: PayloadSource = Field(..., alias="payloadSource") - body: InteractionPayloadBase | None = Field(default=None) + body: PayloadAndBodyBase | None = Field(default=None) + +class DismabiguationRequestsPayload(PayloadAndBodyBase): + """Disambiguation requests payload. Handles requests for the end user to respond to.""" -class DismabiguationRequestsPayload(InteractionPayloadBase): - class Body(InteractionPayloadBase): - class DismabiguationRequest(InteractionPayloadBase): + class Body(PayloadAndBodyBase): + class DismabiguationRequest(PayloadAndBodyBase): assistant_question: str | None = Field(..., alias="assistantQuestion") user_choices: list[str] | None = Field(default=None, alias="userChoices") @@ -65,6 +75,7 @@ class DismabiguationRequest(InteractionPayloadBase): body: Body | None = Field(default=None) def __init__(self, **kwargs): + """Custom init method to pass kwargs to the body.""" super().__init__(**kwargs) body_kwargs = kwargs.get("body", kwargs) @@ -72,9 +83,11 @@ def __init__(self, **kwargs): self.body = self.Body(**body_kwargs) -class AnswerWithSourcesPayload(InteractionPayloadBase): - class Body(InteractionPayloadBase): - class Source(InteractionPayloadBase): +class AnswerWithSourcesPayload(PayloadAndBodyBase): + """Answer with sources payload. Handles the answer and sources for the answer.
The follow up suggestion property is optional and may be used to provide the user with a follow up suggestion.""" + + class Body(PayloadAndBodyBase): + class Source(PayloadAndBodyBase): sql_query: str = Field(alias="sqlQuery") sql_rows: list[dict] = Field(default_factory=list, alias="sqlRows") @@ -94,6 +107,7 @@ class Source(InteractionPayloadBase): body: Body | None = Field(default=None) def __init__(self, **kwargs): + """Custom init method to pass kwargs to the body.""" super().__init__(**kwargs) body_kwargs = kwargs.get("body", kwargs) @@ -101,8 +115,10 @@ def __init__(self, **kwargs): self.body = self.Body(**body_kwargs) -class ProcessingUpdatePayload(InteractionPayloadBase): - class Body(InteractionPayloadBase): +class ProcessingUpdatePayload(PayloadAndBodyBase): + """Processing update payload. Handles updates to the user on the processing status.""" + + class Body(PayloadAndBodyBase): title: str | None = "Processing..." message: str | None = "Processing..." @@ -115,6 +131,7 @@ class Body(InteractionPayloadBase): body: Body | None = Field(default=None) def __init__(self, **kwargs): + """Custom init method to pass kwargs to the body.""" super().__init__(**kwargs) body_kwargs = kwargs.get("body", kwargs) @@ -122,8 +139,10 @@ def __init__(self, **kwargs): self.body = self.Body(**body_kwargs) -class UserMessagePayload(InteractionPayloadBase): - class Body(InteractionPayloadBase): +class UserMessagePayload(PayloadAndBodyBase): + """User message payload. 
Handles the user message and injected parameters.""" + + class Body(PayloadAndBodyBase): user_message: str = Field(..., alias="userMessage") injected_parameters: dict = Field( default_factory=dict, alias="injectedParameters" @@ -154,6 +173,7 @@ def add_defaults(cls, values): body: Body | None = Field(default=None) def __init__(self, **kwargs): + """Custom init method to pass kwargs to the body.""" super().__init__(**kwargs) body_kwargs = kwargs.get("body", kwargs) @@ -162,6 +182,8 @@ def __init__(self, **kwargs): class InteractionPayload(RootModel): + """Interaction payload. Handles the root payload for the interaction""" + root: UserMessagePayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field( ..., discriminator="payload_type" ) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml index aa6199c..5126ff7 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml @@ -103,6 +103,7 @@ system_message: | {{ engine_specific_rules }} + Rows returned will be automatically limited to {{ row_limit }}. Your primary focus is on: diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml index fd900bf..80a2542 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml @@ -14,6 +14,7 @@ system_message: | {{ engine_specific_rules }} + Rows returned will be automatically limited to {{ row_limit }}.