Skip to content

Commit 888fa04

Browse files
Limit string nodes in Polars expressions to constant expressions (#225)
2 parents 64c0297 + 325e0e4 commit 888fa04

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

_duckdb-stubs/__init__.pyi

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,15 +1036,12 @@ class token_type:
10361036
def CaseExpression(condition: Expression, value: Expression) -> Expression: ...
10371037
def CoalesceOperator(*args: Expression) -> Expression: ...
10381038
def ColumnExpression(*args: str) -> Expression: ...
1039-
def ConstantExpression(value: Expression | str) -> Expression: ...
1039+
def ConstantExpression(value: pytyping.Any) -> Expression: ...
10401040
def DefaultExpression() -> Expression: ...
10411041
def FunctionExpression(function_name: str, *args: Expression) -> Expression: ...
1042-
def LambdaExpression(lhs: Expression | str | tuple[str], rhs: Expression) -> Expression: ...
1042+
def LambdaExpression(lhs: pytyping.Any, rhs: Expression) -> Expression: ...
10431043
def SQLExpression(expression: str) -> Expression: ...
1044-
@pytyping.overload
1045-
def StarExpression(*, exclude: Expression | str | tuple[str]) -> Expression: ...
1046-
@pytyping.overload
1047-
def StarExpression() -> Expression: ...
1044+
def StarExpression(*, exclude: pytyping.Any = None) -> Expression: ...
10481045
def aggregate(
10491046
df: pandas.DataFrame,
10501047
aggr_expr: Expression | list[Expression] | str | list[str],

duckdb/polars_io.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,9 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str:
236236
# String type
237237
if dtype == "String" or dtype == "StringOwned":
238238
# Some new formats may store directly under StringOwned
239-
string_val: object | None = value.get("StringOwned", value.get("String", None))
240-
return f"'{string_val}'"
239+
string_val = value.get("StringOwned", value.get("String", None))
240+
# the string must be a string constant
241+
return str(duckdb.ConstantExpression(string_val))
241242

242243
msg = f"Unsupported scalar type {dtype!s}, with value {value}"
243244
raise NotImplementedError(msg)

tests/fast/arrow/test_polars.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,23 @@ def test_polars_lazy_many_batches(self, duckdb_cursor):
639639

640640
assert res == correct
641641

642+
@pytest.mark.parametrize(
643+
"input_str", ["A'dam", 'answer = "42"', "'; DROP TABLE users; --", "line1\nline2\ttab", "", None]
644+
)
645+
def test_expr_with_sql_in_string_node(self, input_str):
646+
"""SQL in a String node in an expression is treated as a constant expression."""
647+
expected = str(duckdb.ConstantExpression(input_str))
648+
649+
# Regular string
650+
tree = {"Scalar": {"String": input_str}}
651+
result = _pl_tree_to_sql(tree)
652+
assert result == expected
653+
654+
# StringOwned
655+
tree = {"Scalar": {"StringOwned": input_str}}
656+
result = _pl_tree_to_sql(tree)
657+
assert result == expected
658+
642659
def test_invalid_expr_json(self):
643660
bad_key_expr = """
644661
{

0 commit comments

Comments
 (0)