Skip to content

Commit 8154645

Browse files
committed
Revert "UNPICK"
This reverts commit ca23c74.
1 parent ca23c74 commit 8154645

File tree

4 files changed

+188
-26
lines changed

4 files changed

+188
-26
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,49 @@ DataFusion's DataFrame API offers a wide range of operations:
126126
# Drop columns
127127
df = df.drop("temporary_column")
128128
129+
String Columns and Expressions
130+
------------------------------
131+
132+
Some ``DataFrame`` methods accept plain strings when an argument refers to an
133+
existing column. These include:
134+
135+
* :py:meth:`~datafusion.DataFrame.select`
136+
* :py:meth:`~datafusion.DataFrame.sort`
137+
* :py:meth:`~datafusion.DataFrame.drop`
138+
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139+
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140+
141+
For such methods, you can pass column names directly:
142+
143+
.. code-block:: python
144+
145+
from datafusion import col, functions as f
146+
147+
df.sort('id')
148+
df.aggregate('id', [f.count(col('value'))])
149+
150+
The same operation can also be written with an explicit column expression:
151+
152+
.. code-block:: python
153+
154+
from datafusion import col, functions as f
155+
156+
df.sort(col('id'))
157+
df.aggregate(col('id'), [f.count(col('value'))])
158+
159+
Whenever an argument represents an expression—such as in
160+
:py:meth:`~datafusion.DataFrame.filter` or
161+
:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns
162+
and wrap constant values with ``lit()`` (also available as ``literal()``):
163+
164+
.. code-block:: python
165+
166+
from datafusion import col, lit
167+
df.filter(col('age') > lit(21))
168+
169+
Without ``lit()`` DataFusion would treat ``21`` as a column name rather than a
170+
constant value.
171+
129172
Terminal Operations
130173
-------------------
131174

python/datafusion/dataframe.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import Expr, SortExpr, sort_or_default
43+
from datafusion.expr import (
44+
_EXPR_TYPE_ERROR,
45+
Expr,
46+
SortExpr,
47+
expr_list_to_raw_expr_list,
48+
sort_or_default,
49+
)
4450
from datafusion.plan import ExecutionPlan, LogicalPlan
4551
from datafusion.record_batch import RecordBatchStream
4652

@@ -394,9 +400,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394400
df = df.select("a", col("b"), col("a").alias("alternate_a"))
395401
396402
"""
397-
exprs_internal = [
398-
Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
399-
]
403+
exprs_internal = expr_list_to_raw_expr_list(exprs)
400404
return DataFrame(self.df.select(*exprs_internal))
401405

402406
def drop(self, *columns: str) -> DataFrame:
@@ -426,7 +430,9 @@ def filter(self, *predicates: Expr) -> DataFrame:
426430
"""
427431
df = self.df
428432
for p in predicates:
429-
df = df.filter(p.expr)
433+
if isinstance(p, str) or not isinstance(p, Expr):
434+
raise TypeError(_EXPR_TYPE_ERROR)
435+
df = df.filter(expr_list_to_raw_expr_list(p)[0])
430436
return DataFrame(df)
431437

432438
def with_column(self, name: str, expr: Expr) -> DataFrame:
@@ -439,6 +445,8 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
439445
Returns:
440446
DataFrame with the new column.
441447
"""
448+
if not isinstance(expr, Expr):
449+
raise TypeError(_EXPR_TYPE_ERROR)
442450
return DataFrame(self.df.with_column(name, expr.expr))
443451

444452
def with_columns(
@@ -472,12 +480,17 @@ def _simplify_expression(
472480
for expr in exprs:
473481
if isinstance(expr, Expr):
474482
expr_list.append(expr.expr)
475-
elif isinstance(expr, Iterable):
476-
expr_list.extend(inner_expr.expr for inner_expr in expr)
483+
elif isinstance(expr, Iterable) and not isinstance(expr, (str, Expr)):
484+
for inner_expr in expr:
485+
if not isinstance(inner_expr, Expr):
486+
raise TypeError(_EXPR_TYPE_ERROR)
487+
expr_list.append(inner_expr.expr)
477488
else:
478-
raise NotImplementedError
489+
raise TypeError(_EXPR_TYPE_ERROR)
479490
if named_exprs:
480491
for alias, expr in named_exprs.items():
492+
if not isinstance(expr, Expr):
493+
raise TypeError(_EXPR_TYPE_ERROR)
481494
expr_list.append(expr.alias(alias).expr)
482495
return expr_list
483496

@@ -503,37 +516,56 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
503516
return DataFrame(self.df.with_column_renamed(old_name, new_name))
504517

505518
def aggregate(
506-
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
519+
self,
520+
group_by: list[Expr | str] | Expr | str,
521+
aggs: list[Expr] | Expr,
507522
) -> DataFrame:
508523
"""Aggregates the rows of the current DataFrame.
509524
510525
Args:
511-
group_by: List of expressions to group by.
526+
group_by: List of expressions or column names to group by.
512527
aggs: List of expressions to aggregate.
513528
514529
Returns:
515530
DataFrame after aggregation.
516531
"""
517-
group_by = group_by if isinstance(group_by, list) else [group_by]
518-
aggs = aggs if isinstance(aggs, list) else [aggs]
532+
group_by_list = group_by if isinstance(group_by, list) else [group_by]
533+
aggs_list = aggs if isinstance(aggs, list) else [aggs]
519534

520-
group_by = [e.expr for e in group_by]
521-
aggs = [e.expr for e in aggs]
522-
return DataFrame(self.df.aggregate(group_by, aggs))
535+
group_by_exprs = expr_list_to_raw_expr_list(group_by_list)
536+
aggs_exprs = []
537+
for agg in aggs_list:
538+
if not isinstance(agg, Expr):
539+
raise TypeError(_EXPR_TYPE_ERROR)
540+
aggs_exprs.append(agg.expr)
541+
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
523542

524-
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
525-
"""Sort the DataFrame by the specified sorting expressions.
543+
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
544+
"""Sort the DataFrame by the specified sorting expressions or column names.
526545
527546
Note that any expression can be turned into a sort expression by
528-
calling its` ``sort`` method.
547+
calling its ``sort`` method.
529548
530549
Args:
531-
exprs: Sort expressions, applied in order.
550+
exprs: Sort expressions or column names, applied in order.
532551
533552
Returns:
534553
DataFrame after sorting.
535554
"""
536-
exprs_raw = [sort_or_default(expr) for expr in exprs]
555+
exprs_raw = []
556+
for e in exprs:
557+
if isinstance(e, SortExpr):
558+
exprs_raw.append(sort_or_default(e))
559+
elif isinstance(e, str):
560+
exprs_raw.append(sort_or_default(Expr.column(e)))
561+
elif isinstance(e, Expr):
562+
exprs_raw.append(sort_or_default(e))
563+
else:
564+
error = (
565+
"Expected Expr or column name, found:"
566+
f" {type(e).__name__}. Use col() or lit() to construct expressions."
567+
)
568+
raise TypeError(error)
537569
return DataFrame(self.df.sort(*exprs_raw))
538570

539571
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -757,7 +789,11 @@ def join_on(
757789
Returns:
758790
DataFrame after join.
759791
"""
760-
exprs = [expr.expr for expr in on_exprs]
792+
exprs = []
793+
for expr in on_exprs:
794+
if not isinstance(expr, Expr):
795+
raise TypeError(_EXPR_TYPE_ERROR)
796+
exprs.append(expr.expr)
761797
return DataFrame(self.df.join_on(right.df, exprs, how))
762798

763799
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
2626

2727
import pyarrow as pa
2828

@@ -39,6 +39,10 @@
3939
if TYPE_CHECKING:
4040
from datafusion.plan import LogicalPlan
4141

42+
43+
# Standard error message for invalid expression types
44+
_EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions"
45+
4246
# The following are imported from the internal representation. We may choose to
4347
# give these all proper wrappers, or to simply leave as is. These were added
4448
# in order to support passing the `test_imports` unit test.
@@ -216,12 +220,26 @@
216220

217221

218222
def expr_list_to_raw_expr_list(
219-
expr_list: Optional[list[Expr] | Expr],
223+
expr_list: Optional[Sequence[Expr | str] | Expr | str],
220224
) -> Optional[list[expr_internal.Expr]]:
221-
"""Helper function to convert an optional list to raw expressions."""
222-
if isinstance(expr_list, Expr):
225+
"""Convert a sequence of expressions or column names to raw expressions."""
226+
if isinstance(expr_list, (Expr, str)):
223227
expr_list = [expr_list]
224-
return [e.expr for e in expr_list] if expr_list is not None else None
228+
if expr_list is None:
229+
return None
230+
raw_exprs: list[expr_internal.Expr] = []
231+
for e in expr_list:
232+
if isinstance(e, str):
233+
raw_exprs.append(Expr.column(e).expr)
234+
elif isinstance(e, Expr):
235+
raw_exprs.append(e.expr)
236+
else:
237+
error = (
238+
"Expected Expr or column name, found:"
239+
f" {type(e).__name__}. {_EXPR_TYPE_ERROR}."
240+
)
241+
raise TypeError(error)
242+
return raw_exprs
225243

226244

227245
def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:

python/tests/test_dataframe.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
WindowFrame,
3434
column,
3535
literal,
36+
col,
3637
)
3738
from datafusion import (
3839
functions as f,
@@ -227,6 +228,13 @@ def test_select_mixed_expr_string(df):
227228
assert result.column(1) == pa.array([1, 2, 3])
228229

229230

231+
def test_select_unsupported(df):
232+
with pytest.raises(
233+
TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)"
234+
):
235+
df.select(1)
236+
237+
230238
def test_filter(df):
231239
df1 = df.filter(column("a") > literal(2)).select(
232240
column("a") + column("b"),
@@ -268,6 +276,32 @@ def test_sort(df):
268276
assert table.to_pydict() == expected
269277

270278

279+
def test_sort_string_and_expression_equivalent(df):
280+
from datafusion import col
281+
282+
result_str = df.sort("a").to_pydict()
283+
result_expr = df.sort(col("a")).to_pydict()
284+
assert result_str == result_expr
285+
286+
287+
def test_sort_unsupported(df):
288+
with pytest.raises(
289+
TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)"
290+
):
291+
df.sort(1)
292+
293+
294+
def test_aggregate_string_and_expression_equivalent(df):
295+
result_str = df.aggregate("a", [f.count()]).to_pydict()
296+
result_expr = df.aggregate(col("a"), [f.count()]).to_pydict()
297+
assert result_str == result_expr
298+
299+
300+
def test_filter_string_unsupported(df):
301+
with pytest.raises(TypeError, match=r"col\(\) or lit\(\)"):
302+
df.filter("a > 1")
303+
304+
271305
def test_drop(df):
272306
df = df.drop("c")
273307

@@ -337,6 +371,11 @@ def test_with_column(df):
337371
assert result.column(2) == pa.array([5, 7, 9])
338372

339373

374+
def test_with_column_invalid_expr(df):
375+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
376+
df.with_column("c", "a")
377+
378+
340379
def test_with_columns(df):
341380
df = df.with_columns(
342381
(column("a") + column("b")).alias("c"),
@@ -368,6 +407,13 @@ def test_with_columns(df):
368407
assert result.column(6) == pa.array([5, 7, 9])
369408

370409

410+
def test_with_columns_invalid_expr(df):
411+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
412+
df.with_columns("a")
413+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
414+
df.with_columns(c="a")
415+
416+
371417
def test_cast(df):
372418
df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
373419
expected = pa.schema(
@@ -526,6 +572,25 @@ def test_join_on():
526572
assert table.to_pydict() == expected
527573

528574

575+
def test_join_on_invalid_expr():
576+
ctx = SessionContext()
577+
578+
batch = pa.RecordBatch.from_arrays(
579+
[pa.array([1, 2]), pa.array([4, 5])],
580+
names=["a", "b"],
581+
)
582+
df = ctx.create_dataframe([[batch]], "l")
583+
df1 = ctx.create_dataframe([[batch]], "r")
584+
585+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
586+
df.join_on(df1, "a")
587+
588+
589+
def test_aggregate_invalid_aggs(df):
590+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
591+
df.aggregate([], "a")
592+
593+
529594
def test_distinct():
530595
ctx = SessionContext()
531596

0 commit comments

Comments
 (0)