Skip to content

Commit 77c900f

Browse files
committed
Support Clause groups and statement separators
These allow for unions to be parsed correctly and for queries to be written with surrounding parenthesis.
1 parent 7af6564 commit 77c900f

3 files changed

Lines changed: 47 additions & 16 deletions

File tree

src/sql_tstring/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from sql_tstring.parser import (
1111
Clause,
12+
ClauseGroup,
1213
Element,
1314
Expression,
1415
ExpressionGroup,
@@ -102,10 +103,10 @@ def sql(
102103
for raw_parsed_query in parsed_queries:
103104
parsed_query = deepcopy(raw_parsed_query)
104105
new_values = _replace_placeholders(parsed_query, 0)
105-
result_str += _print_node(parsed_query, [None] * len(result_values), ctx.dialect)
106+
result_str += " " + _print_node(parsed_query, [None] * len(result_values), ctx.dialect)
106107
result_values.extend(new_values)
107108

108-
return result_str, result_values
109+
return result_str.strip(), result_values
109110

110111

111112
class _ContextManager:
@@ -157,7 +158,12 @@ def _print_node(
157158

158159
match node:
159160
case Statement():
160-
result = " ".join(_print_node(clause, placeholders, dialect) for clause in node.clauses)
161+
addition = " ".join(
162+
_print_node(clause, placeholders, dialect) for clause in node.clauses
163+
)
164+
result = f"{node.separator} {addition}"
165+
case ClauseGroup():
166+
result = f"({" ".join(_print_node(clause, placeholders, dialect) for clause in node.clauses)})" # noqa: E501
161167
case Clause() | ExpressionGroup():
162168
result = ""
163169

src/sql_tstring/parser.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -274,13 +274,14 @@ class ClauseProperties:
274274

275275
@dataclass
276276
class Statement:
277-
clauses: list[Clause] = field(default_factory=list)
277+
clauses: list[Clause | ClauseGroup] = field(default_factory=list)
278278
parent: ExpressionGroup | Function | Group | None = None
279+
separator: str = ""
279280

280281

281282
@dataclass
282283
class Clause:
283-
parent: Statement
284+
parent: ClauseGroup | Statement
284285
properties: ClauseProperties
285286
text: str
286287
expressions: list[Expression] = field(init=False)
@@ -290,6 +291,12 @@ def __post_init__(self) -> None:
290291
self.expressions = [Expression(self)]
291292

292293

294+
@dataclass
295+
class ClauseGroup:
296+
parent: Statement
297+
clauses: list[Clause] = field(default_factory=list)
298+
299+
293300
@dataclass
294301
class Expression:
295302
parent: Clause | ExpressionGroup
@@ -351,7 +358,7 @@ class Operator:
351358

352359

353360
type ParentNode = Clause | Expression | ExpressionGroup | Function | Group
354-
type Node = ParentNode | Literal | Statement
361+
type Node = ParentNode | ClauseGroup | Literal | Statement
355362
type Element = Node | Operator | Part | Placeholder
356363

357364

@@ -387,7 +394,7 @@ def _parse_placeholder(
387394
) -> None:
388395
if isinstance(current_node, (Expression, Function, Group, Literal)):
389396
parent = current_node
390-
elif isinstance(current_node, Statement):
397+
elif isinstance(current_node, (Statement, ClauseGroup)):
391398
raise ValueError("Invalid syntax")
392399
else: # Clause | ExpressionGroup
393400
parent = current_node.expressions[-1]
@@ -461,7 +468,7 @@ def _parse_string(
461468
current_node, consumed = _parse_token(
462469
current_node, raw_current_token, current_token, tokens[index:], statements
463470
)
464-
else: # Statement
471+
else: # ClauseGroup | Statement
465472
current_node, consumed = _parse_token(
466473
current_node, raw_current_token, current_token, tokens[index:], statements
467474
)
@@ -472,7 +479,7 @@ def _parse_string(
472479

473480

474481
def _parse_token(
475-
current_node: ParentNode | Statement,
482+
current_node: ParentNode | ClauseGroup | Statement,
476483
raw_current_token: str,
477484
current_token: str,
478485
tokens: list[str],
@@ -482,8 +489,17 @@ def _parse_token(
482489
return _parse_clause(current_node, tokens)
483490
elif current_token == ";":
484491
statements.append(Statement())
492+
statements[-1].separator = ";"
485493
return statements[-1], 1
486-
elif not isinstance(current_node, Statement):
494+
elif current_token == "union":
495+
statements.append(Statement())
496+
statements[-1].separator = raw_current_token
497+
consumed = 1
498+
if tokens[1].lower() == "all":
499+
statements[-1].separator += f" {tokens[1]}"
500+
consumed = 2
501+
return statements[-1], consumed
502+
elif not isinstance(current_node, (ClauseGroup, Statement)):
487503
if current_token in OPERATORS:
488504
return _parse_operator(current_node, tokens)
489505
elif current_token == "'":
@@ -494,17 +510,21 @@ def _parse_token(
494510
return _parse_function(current_node, raw_current_token[:-1])
495511
elif current_token == ")":
496512
current_node = _find_node( # type: ignore[assignment]
497-
current_node, (ExpressionGroup, Function, Group)
513+
current_node, (ExpressionGroup, Function, Group, ClauseGroup)
498514
)
499515
return current_node.parent, 1
500516
else:
501517
return _parse_part(current_node, raw_current_token)
518+
elif isinstance(current_node, Statement) and current_token == "(":
519+
statement_group = ClauseGroup(current_node)
520+
current_node.clauses.append(statement_group)
521+
return statement_group, 1
502522
else:
503523
raise ValueError("Invalid syntax")
504524

505525

506526
def _parse_clause(
507-
current_node: ParentNode | Statement,
527+
current_node: ParentNode | ClauseGroup | Statement,
508528
tokens: list[str],
509529
) -> tuple[Clause, int]:
510530
index = 0
@@ -523,16 +543,16 @@ def _parse_clause(
523543
statement = Statement(parent=current_node)
524544
current_node.expressions[-1].parts.append(statement)
525545
current_node = statement
526-
else: # Clause | Expression | Statement
527-
current_node = _find_node(current_node, Statement)
546+
else: # Clause | Expression | Statement | ClauseGroup
547+
current_node = _find_node(current_node, (Statement, ClauseGroup)) # type: ignore[assignment] # noqa: E501
528548

529549
clause_properties = cast(ClauseProperties, clause_entry[""])
530550
clause = Clause(
531-
parent=current_node,
551+
parent=current_node, # type: ignore[arg-type]
532552
properties=clause_properties,
533553
text=text,
534554
)
535-
current_node.clauses.append(clause)
555+
current_node.clauses.append(clause) # type: ignore[union-attr]
536556
return clause, index
537557

538558

tests/test_parsing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,8 @@ def test_nested() -> None:
105105
inner = t("x = 'a'", locals())
106106
query, _ = sql("SELECT x FROM y WHERE {inner}", locals())
107107
assert query == "SELECT x FROM y WHERE x = 'a'"
108+
109+
110+
def test_opening_parenthesis() -> None:
111+
query, _ = sql("(SELECT x FROM y) UNION ALL (SELECT z FROM y)", locals())
112+
assert query == "(SELECT x FROM y) UNION ALL (SELECT z FROM y)"

0 commit comments

Comments
 (0)