Skip to content

Commit 89d9e5f

Browse files
committed
Revert "UNPICK"
This reverts commit b1bdf74.
1 parent b1bdf74 commit 89d9e5f

File tree

4 files changed

+174
-20
lines changed

4 files changed

+174
-20
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,49 @@ DataFusion's DataFrame API offers a wide range of operations:
126126
# Drop columns
127127
df = df.drop("temporary_column")
128128
129+
String Columns and Expressions
130+
------------------------------
131+
132+
Some ``DataFrame`` methods accept plain strings when an argument refers to an
133+
existing column. These include:
134+
135+
* :py:meth:`~datafusion.DataFrame.select`
136+
* :py:meth:`~datafusion.DataFrame.sort`
137+
* :py:meth:`~datafusion.DataFrame.drop`
138+
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139+
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140+
141+
For such methods, you can pass column names directly:
142+
143+
.. code-block:: python
144+
145+
from datafusion import col, functions as f
146+
147+
df.sort('id')
148+
df.aggregate('id', [f.count(col('value'))])
149+
150+
The same operation can also be written with an explicit column expression:
151+
152+
.. code-block:: python
153+
154+
from datafusion import col, functions as f
155+
156+
df.sort(col('id'))
157+
df.aggregate(col('id'), [f.count(col('value'))])
158+
159+
Whenever an argument represents an expression—such as in
160+
:py:meth:`~datafusion.DataFrame.filter` or
161+
:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns
162+
and wrap constant values with ``lit()`` (also available as ``literal()``):
163+
164+
.. code-block:: python
165+
166+
from datafusion import col, lit
167+
df.filter(col('age') > lit(21))
168+
169+
Without ``lit()`` DataFusion would treat ``21`` as a column name rather than a
170+
constant value.
171+
129172
Terminal Operations
130173
-------------------
131174

python/datafusion/dataframe.py

Lines changed: 73 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import Expr, SortExpr, sort_or_default
43+
from datafusion.expr import Expr, SortExpr, _to_expr_list, sort_or_default
4444
from datafusion.plan import ExecutionPlan, LogicalPlan
4545
from datafusion.record_batch import RecordBatchStream
4646

@@ -394,9 +394,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394394
df = df.select("a", col("b"), col("a").alias("alternate_a"))
395395
396396
"""
397-
exprs_internal = [
398-
Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
399-
]
397+
exprs_internal = _to_expr_list(exprs)
400398
return DataFrame(self.df.select(*exprs_internal))
401399

402400
def drop(self, *columns: str) -> DataFrame:
@@ -426,6 +424,12 @@ def filter(self, *predicates: Expr) -> DataFrame:
426424
"""
427425
df = self.df
428426
for p in predicates:
427+
if not isinstance(p, Expr):
428+
msg = (
429+
f"Expected Expr, got {type(p).__name__}. "
430+
"Use col() or lit() to construct expressions."
431+
)
432+
raise TypeError(msg)
429433
df = df.filter(p.expr)
430434
return DataFrame(df)
431435

@@ -439,6 +443,12 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
439443
Returns:
440444
DataFrame with the new column.
441445
"""
446+
if not isinstance(expr, Expr):
447+
msg = (
448+
f"Expected Expr, got {type(expr).__name__}. "
449+
"Use col() or lit() to construct expressions."
450+
)
451+
raise TypeError(msg)
442452
return DataFrame(self.df.with_column(name, expr.expr))
443453

444454
def with_columns(
@@ -473,11 +483,28 @@ def _simplify_expression(
473483
if isinstance(expr, Expr):
474484
expr_list.append(expr.expr)
475485
elif isinstance(expr, Iterable):
476-
expr_list.extend(inner_expr.expr for inner_expr in expr)
486+
for inner_expr in expr:
487+
if not isinstance(inner_expr, Expr):
488+
msg = (
489+
f"Expected Expr, got {type(inner_expr).__name__}. "
490+
"Use col() or lit() to construct expressions."
491+
)
492+
raise TypeError(msg)
493+
expr_list.append(inner_expr.expr)
477494
else:
478-
raise NotImplementedError
495+
msg = (
496+
f"Expected Expr, got {type(expr).__name__}. "
497+
"Use col() or lit() to construct expressions."
498+
)
499+
raise TypeError(msg)
479500
if named_exprs:
480501
for alias, expr in named_exprs.items():
502+
if not isinstance(expr, Expr):
503+
msg = (
504+
f"Expected Expr, got {type(expr).__name__}. "
505+
"Use col() or lit() to construct expressions."
506+
)
507+
raise TypeError(msg)
481508
expr_list.append(expr.alias(alias).expr)
482509
return expr_list
483510

@@ -503,37 +530,56 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
503530
return DataFrame(self.df.with_column_renamed(old_name, new_name))
504531

505532
def aggregate(
506-
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
533+
self,
534+
group_by: list[Expr | str] | Expr | str,
535+
aggs: list[Expr] | Expr,
507536
) -> DataFrame:
508537
"""Aggregates the rows of the current DataFrame.
509538
510539
Args:
511-
group_by: List of expressions to group by.
540+
group_by: List of expressions or column names to group by.
512541
aggs: List of expressions to aggregate.
513542
514543
Returns:
515544
DataFrame after aggregation.
516545
"""
517-
group_by = group_by if isinstance(group_by, list) else [group_by]
518-
aggs = aggs if isinstance(aggs, list) else [aggs]
546+
group_by_list = group_by if isinstance(group_by, list) else [group_by]
547+
aggs_list = aggs if isinstance(aggs, list) else [aggs]
519548

520-
group_by = [e.expr for e in group_by]
521-
aggs = [e.expr for e in aggs]
522-
return DataFrame(self.df.aggregate(group_by, aggs))
549+
group_by_exprs = [
550+
Expr.column(e).expr if isinstance(e, str) else e.expr for e in group_by_list
551+
]
552+
aggs_exprs = []
553+
for agg in aggs_list:
554+
if not isinstance(agg, Expr):
555+
msg = (
556+
f"Expected Expr, got {type(agg).__name__}. "
557+
"Use col() or lit() to construct expressions."
558+
)
559+
raise TypeError(msg)
560+
aggs_exprs.append(agg.expr)
561+
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
523562

524-
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
525-
"""Sort the DataFrame by the specified sorting expressions.
563+
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
564+
"""Sort the DataFrame by the specified sorting expressions or column names.
526565
527566
Note that any expression can be turned into a sort expression by
528-
calling its` ``sort`` method.
567+
calling its ``sort`` method.
529568
530569
Args:
531-
exprs: Sort expressions, applied in order.
570+
exprs: Sort expressions or column names, applied in order.
532571
533572
Returns:
534573
DataFrame after sorting.
535574
"""
536-
exprs_raw = [sort_or_default(expr) for expr in exprs]
575+
expr_seq = [e for e in exprs if not isinstance(e, SortExpr)]
576+
raw_exprs_iter = iter(_to_expr_list(expr_seq))
577+
exprs_raw = []
578+
for e in exprs:
579+
if isinstance(e, SortExpr):
580+
exprs_raw.append(sort_or_default(e))
581+
else:
582+
exprs_raw.append(sort_or_default(Expr(next(raw_exprs_iter))))
537583
return DataFrame(self.df.sort(*exprs_raw))
538584

539585
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -757,7 +803,15 @@ def join_on(
757803
Returns:
758804
DataFrame after join.
759805
"""
760-
exprs = [expr.expr for expr in on_exprs]
806+
exprs = []
807+
for expr in on_exprs:
808+
if not isinstance(expr, Expr):
809+
msg = (
810+
f"Expected Expr, got {type(expr).__name__}. "
811+
"Use col() or lit() to construct expressions."
812+
)
813+
raise TypeError(msg)
814+
exprs.append(expr.expr)
761815
return DataFrame(self.df.join_on(right.df, exprs, how))
762816

763817
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
2626

2727
import pyarrow as pa
2828

@@ -215,6 +215,11 @@
215215
]
216216

217217

218+
def _to_expr_list(exprs: Sequence[Expr | str]) -> list[expr_internal.Expr]:
219+
"""Convert a sequence of expressions or column names to raw expressions."""
220+
return [Expr.column(e).expr if isinstance(e, str) else e.expr for e in exprs]
221+
222+
218223
def expr_list_to_raw_expr_list(
219224
expr_list: Optional[list[Expr] | Expr],
220225
) -> Optional[list[expr_internal.Expr]]:

python/tests/test_dataframe.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,27 @@ def test_sort(df):
268268
assert table.to_pydict() == expected
269269

270270

271+
def test_sort_string_and_expression_equivalent(df):
272+
from datafusion import col
273+
274+
result_str = df.sort("a").to_pydict()
275+
result_expr = df.sort(col("a")).to_pydict()
276+
assert result_str == result_expr
277+
278+
279+
def test_aggregate_string_and_expression_equivalent(df):
280+
from datafusion import col
281+
282+
result_str = df.aggregate("a", [f.count()]).to_pydict()
283+
result_expr = df.aggregate(col("a"), [f.count()]).to_pydict()
284+
assert result_str == result_expr
285+
286+
287+
def test_filter_string_unsupported(df):
288+
with pytest.raises(TypeError, match=r"col\(\) or lit\(\)"):
289+
df.filter("a > 1")
290+
291+
271292
def test_drop(df):
272293
df = df.drop("c")
273294

@@ -337,6 +358,11 @@ def test_with_column(df):
337358
assert result.column(2) == pa.array([5, 7, 9])
338359

339360

361+
def test_with_column_invalid_expr(df):
362+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
363+
df.with_column("c", "a")
364+
365+
340366
def test_with_columns(df):
341367
df = df.with_columns(
342368
(column("a") + column("b")).alias("c"),
@@ -368,6 +394,13 @@ def test_with_columns(df):
368394
assert result.column(6) == pa.array([5, 7, 9])
369395

370396

397+
def test_with_columns_invalid_expr(df):
398+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
399+
df.with_columns("a")
400+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
401+
df.with_columns(c="a")
402+
403+
371404
def test_cast(df):
372405
df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
373406
expected = pa.schema(
@@ -526,6 +559,25 @@ def test_join_on():
526559
assert table.to_pydict() == expected
527560

528561

562+
def test_join_on_invalid_expr():
563+
ctx = SessionContext()
564+
565+
batch = pa.RecordBatch.from_arrays(
566+
[pa.array([1, 2]), pa.array([4, 5])],
567+
names=["a", "b"],
568+
)
569+
df = ctx.create_dataframe([[batch]], "l")
570+
df1 = ctx.create_dataframe([[batch]], "r")
571+
572+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
573+
df.join_on(df1, "a")
574+
575+
576+
def test_aggregate_invalid_aggs(df):
577+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
578+
df.aggregate([], "a")
579+
580+
529581
def test_distinct():
530582
ctx = SessionContext()
531583

0 commit comments

Comments
 (0)