Skip to content

Commit f7cf4d8

Browse files
committed
Revert "revert branch UNPICK"
This reverts commit c720298.
1 parent c720298 commit f7cf4d8

File tree

3 files changed

+81
-12
lines changed

3 files changed

+81
-12
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,49 @@ DataFusion's DataFrame API offers a wide range of operations:
126126
# Drop columns
127127
df = df.drop("temporary_column")
128128
129+
String Columns and Expressions
130+
------------------------------
131+
132+
Some ``DataFrame`` methods accept plain strings when an argument refers to an
133+
existing column. These include:
134+
135+
* :py:meth:`~datafusion.DataFrame.select`
136+
* :py:meth:`~datafusion.DataFrame.sort`
137+
* :py:meth:`~datafusion.DataFrame.drop`
138+
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139+
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140+
141+
For such methods, you can pass column names directly:
142+
143+
.. code-block:: python
144+
145+
from datafusion import col, functions as f
146+
147+
df.sort('id')
148+
df.aggregate('id', [f.count(col('value'))])
149+
150+
The same operations can also be written with explicit column expressions:
151+
152+
.. code-block:: python
153+
154+
from datafusion import col, functions as f
155+
156+
df.sort(col('id'))
157+
df.aggregate(col('id'), [f.count(col('value'))])
158+
159+
Whenever an argument represents an expression—such as in
160+
:py:meth:`~datafusion.DataFrame.filter` or
161+
:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns
162+
and wrap constant values with ``lit()`` (also available as ``literal()``):
163+
164+
.. code-block:: python
165+
166+
from datafusion import col, lit
167+
df.filter(col('age') > lit(21))
168+
169+
Wrapping ``21`` in ``lit()`` makes it explicit that it is a constant value
170+
rather than a reference to a column.
171+
129172
Terminal Operations
130173
-------------------
131174

python/datafusion/dataframe.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,12 @@ def filter(self, *predicates: Expr) -> DataFrame:
426426
"""
427427
df = self.df
428428
for p in predicates:
429+
if not isinstance(p, Expr):
430+
msg = (
431+
f"Expected Expr, got {type(p).__name__}. "
432+
"Use col() or lit() to construct expressions."
433+
)
434+
raise TypeError(msg)
429435
df = df.filter(p.expr)
430436
return DataFrame(df)
431437

@@ -503,37 +509,44 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
503509
return DataFrame(self.df.with_column_renamed(old_name, new_name))
504510

505511
def aggregate(
506-
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
512+
self,
513+
group_by: list[Expr | str] | Expr | str,
514+
aggs: list[Expr] | Expr,
507515
) -> DataFrame:
508516
"""Aggregates the rows of the current DataFrame.
509517
510518
Args:
511-
group_by: List of expressions to group by.
519+
group_by: List of expressions or column names to group by.
512520
aggs: List of expressions to aggregate.
513521
514522
Returns:
515523
DataFrame after aggregation.
516524
"""
517-
group_by = group_by if isinstance(group_by, list) else [group_by]
518-
aggs = aggs if isinstance(aggs, list) else [aggs]
525+
group_by_list = group_by if isinstance(group_by, list) else [group_by]
526+
aggs_list = aggs if isinstance(aggs, list) else [aggs]
519527

520-
group_by = [e.expr for e in group_by]
521-
aggs = [e.expr for e in aggs]
522-
return DataFrame(self.df.aggregate(group_by, aggs))
528+
group_by_exprs = [
529+
Expr.column(e).expr if isinstance(e, str) else e.expr for e in group_by_list
530+
]
531+
aggs_exprs = [e.expr for e in aggs_list]
532+
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
523533

524-
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
525-
"""Sort the DataFrame by the specified sorting expressions.
534+
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
535+
"""Sort the DataFrame by the specified sorting expressions or column names.
526536
527537
Note that any expression can be turned into a sort expression by
528-
calling its` ``sort`` method.
538+
calling its ``sort`` method.
529539
530540
Args:
531-
exprs: Sort expressions, applied in order.
541+
exprs: Sort expressions or column names, applied in order.
532542
533543
Returns:
534544
DataFrame after sorting.
535545
"""
536-
exprs_raw = [sort_or_default(expr) for expr in exprs]
546+
exprs_raw = [
547+
sort_or_default(Expr.column(expr) if isinstance(expr, str) else expr)
548+
for expr in exprs
549+
]
537550
return DataFrame(self.df.sort(*exprs_raw))
538551

539552
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:

python/tests/test_dataframe.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,19 @@ def test_sort(df):
268268
assert table.to_pydict() == expected
269269

270270

271+
def test_sort_string_and_expression_equivalent(df):
272+
from datafusion import col
273+
274+
result_str = df.sort("a").to_pydict()
275+
result_expr = df.sort(col("a")).to_pydict()
276+
assert result_str == result_expr
277+
278+
279+
def test_filter_string_unsupported(df):
280+
with pytest.raises(TypeError, match=r"col\(\) or lit\(\)"):
281+
df.filter("a > 1")
282+
283+
271284
def test_drop(df):
272285
df = df.drop("c")
273286

0 commit comments

Comments
 (0)