Skip to content

Commit 985be53

Browse files
committed
Introduce LiteralValues
This allows for a interpolation value to be used in a location where a column would be required as its literal value. This corresponds to a placeholder for strings to avoid SQL injection, but for simple types such as null, bool, and numbers (if allow numeric is in the context) the direct value. The naming follows the RewritingValue convention. This is not documented to allow usage testing first.
1 parent b4f65ed commit 985be53

File tree

2 files changed

+71
-37
lines changed

2 files changed

+71
-37
lines changed

src/sql_tstring/__init__.py

Lines changed: 57 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ class RewritingValue(Enum):
4040
IS_NOT_NULL = auto()
4141

4242

43+
class LiteralValue:
44+
def __init__(self, value: typing.Any) -> None:
45+
self.value = value
46+
47+
4348
type AbsentType = typing.Literal[RewritingValue.ABSENT]
4449
Absent: AbsentType = RewritingValue.ABSENT
4550
IsNull = RewritingValue.IS_NULL
@@ -135,32 +140,42 @@ def __exit__(
135140
def _safely_convert_placeholder_value(
136141
value: object,
137142
*,
143+
parent_node: Expression | Function | Group | Literal,
138144
allow_numeric: bool = False,
139145
case_sensitive: set[str] | None = None,
140146
case_insensitive: set[str] | None = None,
141147
value_type: type = str,
142-
) -> str:
143-
if case_sensitive is None:
144-
case_sensitive = set()
145-
if case_insensitive is None:
146-
case_insensitive = set()
147-
148-
if value is None:
149-
return "NULL"
150-
if isinstance(value, bool):
151-
return str(value)
152-
if allow_numeric and isinstance(value, Number):
153-
return str(value)
154-
155-
if not isinstance(value, value_type):
156-
raise ValueError(f"{value} is not valid, must be {value_type}")
157-
if isinstance(value, str) and (
158-
value not in case_sensitive and value.lower() not in case_insensitive
159-
):
160-
raise ValueError(
161-
f"{value} is not valid, must be one of {case_sensitive} or {case_insensitive}"
162-
)
163-
return str(value)
148+
) -> Part | Placeholder:
149+
if isinstance(value, LiteralValue):
150+
if value.value is None:
151+
return Part(text="NULL", parent=parent_node)
152+
elif isinstance(value.value, bool) or (allow_numeric and isinstance(value.value, Number)):
153+
return Part(text=str(value.value), parent=parent_node)
154+
else:
155+
return Placeholder(parent=parent_node, value=value.value)
156+
else:
157+
if case_sensitive is None:
158+
case_sensitive = set()
159+
if case_insensitive is None:
160+
case_insensitive = set()
161+
162+
if value is None:
163+
text = "NULL"
164+
elif isinstance(value, bool):
165+
text = str(value)
166+
elif allow_numeric and isinstance(value, Number):
167+
text = str(value)
168+
elif not isinstance(value, value_type):
169+
raise ValueError(f"{value} is not valid, must be {value_type}")
170+
elif isinstance(value, str) and (
171+
value not in case_sensitive and value.lower() not in case_insensitive
172+
):
173+
raise ValueError(
174+
f"{value} is not valid, must be one of {case_sensitive} or {case_insensitive}"
175+
)
176+
else:
177+
text = str(value)
178+
return Part(text=text, parent=parent_node)
164179

165180

166181
def _print_node(
@@ -301,31 +316,35 @@ def _replace_placeholder(
301316
else:
302317
match placeholder_type:
303318
case PlaceholderType.COLUMN:
304-
text = _safely_convert_placeholder_value(
305-
value, allow_numeric=ctx.allow_numeric, case_sensitive=ctx.columns
319+
new_node = _safely_convert_placeholder_value(
320+
value,
321+
allow_numeric=ctx.allow_numeric,
322+
case_sensitive=ctx.columns,
323+
parent_node=node.parent,
306324
)
307-
new_node = Part(text=text, parent=node.parent)
308325
case PlaceholderType.FRAME:
309-
text = _safely_convert_placeholder_value(value, value_type=int)
310-
new_node = Part(text=text, parent=node.parent)
326+
new_node = _safely_convert_placeholder_value(
327+
value, value_type=int, parent_node=node.parent
328+
)
311329
case PlaceholderType.LOCK:
312-
text = _safely_convert_placeholder_value(
313-
value, case_insensitive={"", "nowait", "skip locked"}
330+
new_node = _safely_convert_placeholder_value(
331+
value, case_insensitive={"", "nowait", "skip locked"}, parent_node=node.parent
314332
)
315-
new_node = Part(text=text, parent=node.parent)
316333
case PlaceholderType.SORT:
317-
text = _safely_convert_placeholder_value(
334+
new_node = _safely_convert_placeholder_value(
318335
value,
319336
allow_numeric=ctx.allow_numeric,
320337
case_sensitive=ctx.columns,
321338
case_insensitive={"asc", "ascending", "desc", "descending"},
339+
parent_node=node.parent,
322340
)
323-
new_node = Part(text=text, parent=node.parent)
324341
case PlaceholderType.TABLE:
325-
text = _safely_convert_placeholder_value(
326-
value, allow_numeric=ctx.allow_numeric, case_sensitive=ctx.tables
342+
new_node = _safely_convert_placeholder_value(
343+
value,
344+
allow_numeric=ctx.allow_numeric,
345+
case_sensitive=ctx.tables,
346+
parent_node=node.parent,
327347
)
328-
new_node = Part(text=text, parent=node.parent)
329348
case _:
330349
if (
331350
value is RewritingValue.IS_NULL or value is RewritingValue.IS_NOT_NULL
@@ -339,7 +358,9 @@ def _replace_placeholder(
339358
new_node = Part(text="NULL", parent=node.parent)
340359
else:
341360
new_node = node
342-
result.append(value) # type: ignore[arg-type]
361+
362+
if isinstance(new_node, Placeholder):
363+
result.append(new_node.value) # type: ignore[arg-type]
343364

344365
if isinstance(node.parent, (Expression, ExpressionGroup, Function, Group)):
345366
node.parent.parts[index] = new_node

tests/test_parameters.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from sql_tstring import RewritingValue, sql, sql_context, t
5+
from sql_tstring import LiteralValue, RewritingValue, sql, sql_context, t
66

77
TZ = "uk"
88

@@ -70,6 +70,19 @@ def test_placeholders(query: str, expected_query: str, expected_values: list[Any
7070
assert (expected_query, expected_values) == sql(query, locals() | globals())
7171

7272

73+
def test_literal_value() -> None:
74+
col = "col"
75+
a = "L1"
76+
al = LiteralValue("L2")
77+
b = LiteralValue(None)
78+
c = LiteralValue(1)
79+
d = LiteralValue(True)
80+
with sql_context(allow_numeric=True, columns={"col"}):
81+
query, values = sql("SELECT {col}, '{a}', {al}, {b}, {c}, {d}", locals())
82+
assert query == "SELECT col , ? , ? , NULL , 1 , True"
83+
assert values == ["L1", "L2"]
84+
85+
7386
@pytest.mark.parametrize(
7487
"query, expected_query, expected_values",
7588
[

0 commit comments

Comments
 (0)