Skip to content

Commit c6db309

Browse files
committed
Support frame clauses
These are used in window functions and would likely only need parameters for offsets, as implemented.
1 parent 5c6bbbd commit c6db309

3 files changed

Lines changed: 101 additions & 31 deletions

File tree

src/sql_tstring/__init__.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,15 @@ def _check_valid(
131131
*,
132132
case_sensitive: set[str] | None = None,
133133
case_insensitive: set[str] | None = None,
134+
value_type: type = str,
134135
) -> None:
135136
if case_sensitive is None:
136137
case_sensitive = set()
137138
if case_insensitive is None:
138139
case_insensitive = set()
139-
if not isinstance(value, str) or (
140+
if not isinstance(value, value_type):
141+
raise ValueError(f"{value} is not valid, must be {value_type}")
142+
if isinstance(value, str) and (
140143
value not in case_sensitive and value.lower() not in case_insensitive
141144
):
142145
raise ValueError(
@@ -274,36 +277,40 @@ def _replace_placeholder(
274277
)
275278
result.append(value)
276279
else:
277-
if clause is not None and clause.text.lower() == "order by":
278-
_check_valid(
279-
value,
280-
case_sensitive=ctx.columns,
281-
case_insensitive={"asc", "ascending", "desc", "descending"},
282-
)
283-
new_node = Part(text=typing.cast(str, value), parent=node.parent)
284-
elif placeholder_type == PlaceholderType.COLUMN:
285-
_check_valid(value, case_sensitive=ctx.columns)
286-
new_node = Part(text=typing.cast(str, value), parent=node.parent)
287-
elif placeholder_type == PlaceholderType.TABLE:
288-
_check_valid(value, case_sensitive=ctx.tables)
289-
new_node = Part(text=typing.cast(str, value), parent=node.parent)
290-
elif placeholder_type == PlaceholderType.LOCK:
291-
_check_valid(value, case_insensitive={"", "nowait", "skip locked"})
292-
new_node = Part(text=typing.cast(str, value), parent=node.parent)
293-
else:
294-
if (
295-
value is RewritingValue.IS_NULL or value is RewritingValue.IS_NOT_NULL
296-
) and placeholder_type == PlaceholderType.VARIABLE_CONDITION:
297-
for part in node.parent.parts:
298-
if isinstance(part, Operator):
299-
if value is RewritingValue.IS_NULL:
300-
part.text = "IS"
301-
else:
302-
part.text = "IS NOT"
303-
new_node = Part(text="NULL", parent=node.parent)
304-
else:
305-
new_node = node
306-
result.append(value)
280+
match placeholder_type:
281+
case PlaceholderType.COLUMN:
282+
_check_valid(value, case_sensitive=ctx.columns)
283+
new_node = Part(text=typing.cast(str, value), parent=node.parent)
284+
case PlaceholderType.FRAME:
285+
_check_valid(value, value_type=int)
286+
new_node = Part(text=str(value), parent=node.parent)
287+
case PlaceholderType.LOCK:
288+
_check_valid(value, case_insensitive={"", "nowait", "skip locked"})
289+
new_node = Part(text=typing.cast(str, value), parent=node.parent)
290+
case PlaceholderType.SORT:
291+
_check_valid(
292+
value,
293+
case_sensitive=ctx.columns,
294+
case_insensitive={"asc", "ascending", "desc", "descending"},
295+
)
296+
new_node = Part(text=typing.cast(str, value), parent=node.parent)
297+
case PlaceholderType.TABLE:
298+
_check_valid(value, case_sensitive=ctx.tables)
299+
new_node = Part(text=typing.cast(str, value), parent=node.parent)
300+
case _:
301+
if (
302+
value is RewritingValue.IS_NULL or value is RewritingValue.IS_NOT_NULL
303+
) and placeholder_type == PlaceholderType.VARIABLE_CONDITION:
304+
for part in node.parent.parts:
305+
if isinstance(part, Operator):
306+
if value is RewritingValue.IS_NULL:
307+
part.text = "IS"
308+
else:
309+
part.text = "IS NOT"
310+
new_node = Part(text="NULL", parent=node.parent)
311+
else:
312+
new_node = node
313+
result.append(value)
307314

308315
if isinstance(node.parent, (Expression, ExpressionGroup, Function, Group)):
309316
node.parent.parts[index] = new_node

src/sql_tstring/parser.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
class PlaceholderType(Enum):
2020
COLUMN = auto()
2121
DISALLOWED = auto()
22+
FRAME = auto()
2223
LOCK = auto()
24+
SORT = auto()
2325
TABLE = auto()
2426
VARIABLE = auto()
2527
VARIABLE_CONDITION = auto()
@@ -117,6 +119,13 @@ class ClauseProperties:
117119
),
118120
},
119121
"order": {
122+
"by": {
123+
"": ClauseProperties(
124+
allow_empty=False, placeholder_type=PlaceholderType.SORT, separators={","}
125+
)
126+
},
127+
},
128+
"partition": {
120129
"by": {
121130
"": ClauseProperties(
122131
allow_empty=False, placeholder_type=PlaceholderType.COLUMN, separators={","}
@@ -152,6 +161,11 @@ class ClauseProperties:
152161
allow_empty=False, placeholder_type=PlaceholderType.TABLE, separators=set()
153162
)
154163
},
164+
"groups": {
165+
"": ClauseProperties(
166+
allow_empty=False, placeholder_type=PlaceholderType.FRAME, separators=set()
167+
)
168+
},
155169
"having": {
156170
"": ClauseProperties(
157171
allow_empty=False,
@@ -172,11 +186,21 @@ class ClauseProperties:
172186
allow_empty=False, placeholder_type=PlaceholderType.VARIABLE, separators=set()
173187
)
174188
},
189+
"range": {
190+
"": ClauseProperties(
191+
allow_empty=False, placeholder_type=PlaceholderType.FRAME, separators=set()
192+
)
193+
},
175194
"returning": {
176195
"": ClauseProperties(
177196
allow_empty=False, placeholder_type=PlaceholderType.DISALLOWED, separators={","}
178197
)
179198
},
199+
"rows": {
200+
"": ClauseProperties(
201+
allow_empty=False, placeholder_type=PlaceholderType.FRAME, separators=set()
202+
)
203+
},
180204
"select": {
181205
"": ClauseProperties(
182206
allow_empty=False, placeholder_type=PlaceholderType.COLUMN, separators={","}

tests/test_identifiers.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,45 @@ def test_order_by_invalid_column() -> None:
3131
sql("SELECT x FROM y ORDER BY {a}, {b}", locals())
3232

3333

34+
def test_partition_by() -> None:
35+
a = RewritingValue.ABSENT
36+
b = "x"
37+
with sql_context(columns={"x"}):
38+
assert ("SELECT x OVER(PARTITION BY x) FROM y", []) == sql(
39+
"SELECT x OVER(PARTITION BY {b}) FROM y", locals()
40+
)
41+
42+
43+
def test_partition_by_invalid_column() -> None:
44+
a = RewritingValue.ABSENT
45+
b = "x"
46+
with pytest.raises(ValueError):
47+
sql("SELECT x OVER(PARTITION BY {b}) FROM y", locals())
48+
49+
50+
@pytest.mark.parametrize(
51+
"frame_clause",
52+
["GROUPS", "RANGE", "ROWS"],
53+
)
54+
def test_frame_clause_int(frame_clause: str) -> None:
55+
a = RewritingValue.ABSENT
56+
b = 2
57+
assert (f"SELECT x OVER(PARTITION BY x {frame_clause} 2 PRECEDING) FROM y", []) == sql(
58+
f"SELECT x OVER(PARTITION BY x {frame_clause} {{b}} PRECEDING) FROM y", locals()
59+
)
60+
61+
62+
@pytest.mark.parametrize(
63+
"frame_clause",
64+
["GROUPS", "RANGE", "ROWS"],
65+
)
66+
def test_frame_clause_invalid(frame_clause: str) -> None:
67+
a = RewritingValue.ABSENT
68+
b = "INVALID"
69+
with pytest.raises(ValueError):
70+
sql(f"SELECT x OVER(PARTITION BY x {frame_clause} {{b}} PRECEDING) FROM y", locals())
71+
72+
3473
@pytest.mark.parametrize(
3574
"lock_type, expected",
3675
(

0 commit comments

Comments
 (0)