Skip to content

Commit d5f0aa7

Browse files
committed
add SQL expression support for with_columns
1 parent d9c90d2 commit d5f0aa7

File tree

2 files changed

+53
-14
lines changed

2 files changed

+53
-14
lines changed

python/datafusion/dataframe.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -545,13 +545,13 @@ def with_column(self, name: str, expr: Expr | str) -> DataFrame:
545545
return DataFrame(self.df.with_column(name, ensure_expr(expr)))
546546

547547
def with_columns(
548-
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
548+
self, *exprs: Expr | str | Iterable[Expr | str], **named_exprs: Expr | str
549549
) -> DataFrame:
550550
"""Add columns to the DataFrame.
551551
552-
By passing expressions, iterables of expressions, or named expressions.
552+
Accepts expressions, SQL expression strings, iterables of either, or named expressions.
553553
All expressions must be :class:`~datafusion.expr.Expr` objects created via
554-
:func:`datafusion.col` or :func:`datafusion.lit`.
554+
:func:`datafusion.col` or :func:`datafusion.lit`, or SQL expression strings.
555555
To pass named expressions use the form ``name=Expr``.
556556
557557
Example usage: The following will add 4 columns labeled ``a``, ``b``, ``c``,
@@ -565,14 +565,33 @@ def with_columns(
565565
)
566566
567567
Args:
568-
exprs: Either a single expression or an iterable of expressions to add.
568+
exprs: Expressions to add, each given as a single expression, a SQL expression string, or an iterable of either.
569569
named_exprs: Named expressions in the form of ``name=expr``
570570
571571
Returns:
572572
DataFrame with the new columns added.
573573
"""
574-
expressions = ensure_expr_list(exprs)
574+
expressions = []
575+
for expr in exprs:
576+
if isinstance(expr, str):
577+
expr = self.parse_sql_expr(expr)
578+
expressions.append(ensure_expr(expr))
579+
elif isinstance(expr, Iterable) and not isinstance(
580+
expr, (Expr, str, bytes, bytearray)
581+
):
582+
expressions.extend(
583+
[
584+
self.parse_sql_expr(e).expr
585+
if isinstance(e, str)
586+
else ensure_expr(e)
587+
for e in expr
588+
]
589+
)
590+
else:
591+
expressions.append(ensure_expr(expr))
592+
575593
for alias, expr in named_exprs.items():
594+
expr = self.parse_sql_expr(expr) if isinstance(expr, str) else expr
576595
ensure_expr(expr)
577596
expressions.append(expr.alias(alias).expr)
578597

python/tests/test_dataframe.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -538,15 +538,35 @@ def test_with_columns(df):
538538
assert result.column(6) == pa.array([5, 7, 9])
539539

540540

541-
def test_with_columns_invalid_expr(df):
542-
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
543-
df.with_columns("a")
544-
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
545-
df.with_columns(c="a")
546-
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
547-
df.with_columns(["a"])
548-
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
549-
df.with_columns(c=["a"])
541+
def test_with_columns_str(df):
    """Add columns from SQL strings given positionally, in a list, and by keyword."""
    listed = ["a + b as e", "a + b as f"]
    with_sums = df.with_columns("a + b as c", "a + b as d", listed, g="a + b")

    # collect the result and take the first (and only) batch
    batch = with_sums.collect()[0]

    # Original columns come first, followed by the new ones in call order.
    expected_names = ["a", "b", "c", "d", "e", "f", "g"]
    for index, name in enumerate(expected_names):
        assert batch.schema.field(index).name == name

    # Every added column is the element-wise sum of ``a`` and ``b``.
    sums = [5, 7, 9]
    expected_values = [[1, 2, 3], [4, 5, 6]] + [sums] * 5
    for index, values in enumerate(expected_values):
        assert batch.column(index) == pa.array(values)
550570

551571

552572
def test_cast(df):

0 commit comments

Comments
 (0)