Skip to content

Commit 26e7e70

Browse files
committed
Revert "UNPICK"
This reverts commit 8003528.
1 parent 8003528 commit 26e7e70

File tree

4 files changed

+192
-29
lines changed

4 files changed

+192
-29
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,51 @@ DataFusion's DataFrame API offers a wide range of operations:
126126
# Drop columns
127127
df = df.drop("temporary_column")
128128
129+
String Columns and Expressions
130+
------------------------------
131+
132+
Some ``DataFrame`` methods accept plain strings when an argument refers to an
133+
existing column. These include:
134+
135+
* :py:meth:`~datafusion.DataFrame.select`
136+
* :py:meth:`~datafusion.DataFrame.sort`
137+
* :py:meth:`~datafusion.DataFrame.drop`
138+
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139+
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140+
141+
For such methods, you can pass column names directly:
142+
143+
.. code-block:: python
144+
145+
from datafusion import col, column, functions as f
146+
147+
df.sort('id')
148+
df.aggregate('id', [f.count(col('value'))])
149+
150+
The same operations can also be written with explicit column expressions:
151+
152+
.. code-block:: python
153+
154+
from datafusion import col, column, functions as f
155+
156+
df.sort(col('id'))
157+
df.aggregate(col('id'), [f.count(col('value'))])
158+
159+
Note that ``column()`` is an alias of ``col()``, so you can use either name.
160+
161+
Whenever an argument represents an expression—such as in
162+
:py:meth:`~datafusion.DataFrame.filter` or
163+
:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns
164+
and wrap constant values with ``lit()`` (also available as ``literal()``):
165+
166+
.. code-block:: python
167+
168+
from datafusion import col, lit
169+
df.filter(col('age') > lit(21))
170+
171+
Without ``lit()`` DataFusion would treat ``21`` as a column name rather than a
172+
constant value.
173+
129174
Terminal Operations
130175
-------------------
131176

python/datafusion/dataframe.py

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import Expr, SortExpr, sort_or_default
43+
from datafusion.expr import (
44+
_EXPR_TYPE_ERROR,
45+
Expr,
46+
SortExpr,
47+
expr_list_to_raw_expr_list,
48+
sort_or_default,
49+
)
4450
from datafusion.plan import ExecutionPlan, LogicalPlan
4551
from datafusion.record_batch import RecordBatchStream
4652

@@ -394,9 +400,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394400
df = df.select("a", col("b"), col("a").alias("alternate_a"))
395401
396402
"""
397-
exprs_internal = [
398-
Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
399-
]
403+
exprs_internal = expr_list_to_raw_expr_list(exprs)
400404
return DataFrame(self.df.select(*exprs_internal))
401405

402406
def drop(self, *columns: str) -> DataFrame:
@@ -426,7 +430,9 @@ def filter(self, *predicates: Expr) -> DataFrame:
426430
"""
427431
df = self.df
428432
for p in predicates:
429-
df = df.filter(p.expr)
433+
if isinstance(p, str) or not isinstance(p, Expr):
434+
raise TypeError(_EXPR_TYPE_ERROR)
435+
df = df.filter(expr_list_to_raw_expr_list(p)[0])
430436
return DataFrame(df)
431437

432438
def with_column(self, name: str, expr: Expr) -> DataFrame:
@@ -439,6 +445,8 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
439445
Returns:
440446
DataFrame with the new column.
441447
"""
448+
if not isinstance(expr, Expr):
449+
raise TypeError(_EXPR_TYPE_ERROR)
442450
return DataFrame(self.df.with_column(name, expr.expr))
443451

444452
def with_columns(
@@ -470,14 +478,18 @@ def _simplify_expression(
470478
) -> list[expr_internal.Expr]:
471479
expr_list = []
472480
for expr in exprs:
473-
if isinstance(expr, Expr):
474-
expr_list.append(expr.expr)
475-
elif isinstance(expr, Iterable):
476-
expr_list.extend(inner_expr.expr for inner_expr in expr)
477-
else:
478-
raise NotImplementedError
481+
if isinstance(expr, str):
482+
raise TypeError(_EXPR_TYPE_ERROR)
483+
if isinstance(expr, Iterable) and not isinstance(expr, Expr):
484+
if any(not isinstance(inner_expr, Expr) for inner_expr in expr):
485+
raise TypeError(_EXPR_TYPE_ERROR)
486+
elif not isinstance(expr, Expr):
487+
raise TypeError(_EXPR_TYPE_ERROR)
488+
expr_list.extend(expr_list_to_raw_expr_list(expr))
479489
if named_exprs:
480490
for alias, expr in named_exprs.items():
491+
if not isinstance(expr, Expr):
492+
raise TypeError(_EXPR_TYPE_ERROR)
481493
expr_list.append(expr.alias(alias).expr)
482494
return expr_list
483495

@@ -503,37 +515,56 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
503515
return DataFrame(self.df.with_column_renamed(old_name, new_name))
504516

505517
def aggregate(
506-
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
518+
self,
519+
group_by: list[Expr | str] | Expr | str,
520+
aggs: list[Expr] | Expr,
507521
) -> DataFrame:
508522
"""Aggregates the rows of the current DataFrame.
509523
510524
Args:
511-
group_by: List of expressions to group by.
525+
group_by: List of expressions or column names to group by.
512526
aggs: List of expressions to aggregate.
513527
514528
Returns:
515529
DataFrame after aggregation.
516530
"""
517-
group_by = group_by if isinstance(group_by, list) else [group_by]
518-
aggs = aggs if isinstance(aggs, list) else [aggs]
531+
group_by_list = group_by if isinstance(group_by, list) else [group_by]
532+
aggs_list = aggs if isinstance(aggs, list) else [aggs]
519533

520-
group_by = [e.expr for e in group_by]
521-
aggs = [e.expr for e in aggs]
522-
return DataFrame(self.df.aggregate(group_by, aggs))
534+
group_by_exprs = expr_list_to_raw_expr_list(group_by_list)
535+
aggs_exprs = []
536+
for agg in aggs_list:
537+
if not isinstance(agg, Expr):
538+
raise TypeError(_EXPR_TYPE_ERROR)
539+
aggs_exprs.append(agg.expr)
540+
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
523541

524-
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
525-
"""Sort the DataFrame by the specified sorting expressions.
542+
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
543+
"""Sort the DataFrame by the specified sorting expressions or column names.
526544
527545
Note that any expression can be turned into a sort expression by
528-
calling its` ``sort`` method.
546+
calling its ``sort`` method.
529547
530548
Args:
531-
exprs: Sort expressions, applied in order.
549+
exprs: Sort expressions or column names, applied in order.
532550
533551
Returns:
534552
DataFrame after sorting.
535553
"""
536-
exprs_raw = [sort_or_default(expr) for expr in exprs]
554+
exprs_raw = []
555+
for e in exprs:
556+
if isinstance(e, SortExpr):
557+
exprs_raw.append(sort_or_default(e))
558+
elif isinstance(e, str):
559+
exprs_raw.append(sort_or_default(Expr.column(e)))
560+
elif isinstance(e, Expr):
561+
exprs_raw.append(sort_or_default(e))
562+
else:
563+
error = (
564+
"Expected Expr or column name, found:"
565+
f" {type(e).__name__}. {_EXPR_TYPE_ERROR}."
566+
)
567+
raise TypeError(error)
537568
return DataFrame(self.df.sort(*exprs_raw))
538569

539570
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -757,7 +788,11 @@ def join_on(
757788
Returns:
758789
DataFrame after join.
759790
"""
760-
exprs = [expr.expr for expr in on_exprs]
791+
exprs = []
792+
for expr in on_exprs:
793+
if not isinstance(expr, Expr):
794+
raise TypeError(_EXPR_TYPE_ERROR)
795+
exprs.append(expr.expr)
761796
return DataFrame(self.df.join_on(right.df, exprs, how))
762797

763798
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
2626

2727
import pyarrow as pa
2828

@@ -39,6 +39,10 @@
3939
if TYPE_CHECKING:
4040
from datafusion.plan import LogicalPlan
4141

42+
43+
# Standard error message for invalid expression types
44+
_EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions"
45+
4246
# The following are imported from the internal representation. We may choose to
4347
# give these all proper wrappers, or to simply leave as is. These were added
4448
# in order to support passing the `test_imports` unit test.
@@ -216,12 +220,26 @@
216220

217221

218222
def expr_list_to_raw_expr_list(
219-
expr_list: Optional[list[Expr] | Expr],
223+
expr_list: Optional[Sequence[Expr | str] | Expr | str],
220224
) -> Optional[list[expr_internal.Expr]]:
221-
"""Helper function to convert an optional list to raw expressions."""
222-
if isinstance(expr_list, Expr):
225+
"""Convert a sequence of expressions or column names to raw expressions."""
226+
if isinstance(expr_list, (Expr, str)):
223227
expr_list = [expr_list]
224-
return [e.expr for e in expr_list] if expr_list is not None else None
228+
if expr_list is None:
229+
return None
230+
raw_exprs: list[expr_internal.Expr] = []
231+
for e in expr_list:
232+
if isinstance(e, str):
233+
raw_exprs.append(Expr.column(e).expr)
234+
elif isinstance(e, Expr):
235+
raw_exprs.append(e.expr)
236+
else:
237+
error = (
238+
"Expected Expr or column name, found:"
239+
f" {type(e).__name__}. {_EXPR_TYPE_ERROR}."
240+
)
241+
raise TypeError(error)
242+
return raw_exprs
225243

226244

227245
def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:

python/tests/test_dataframe.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
WindowFrame,
3434
column,
3535
literal,
36+
col,
3637
)
3738
from datafusion import (
3839
functions as f,
@@ -227,6 +228,13 @@ def test_select_mixed_expr_string(df):
227228
assert result.column(1) == pa.array([1, 2, 3])
228229

229230

231+
def test_select_unsupported(df):
232+
with pytest.raises(
233+
TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)"
234+
):
235+
df.select(1)
236+
237+
230238
def test_filter(df):
231239
df1 = df.filter(column("a") > literal(2)).select(
232240
column("a") + column("b"),
@@ -268,6 +276,32 @@ def test_sort(df):
268276
assert table.to_pydict() == expected
269277

270278

279+
def test_sort_string_and_expression_equivalent(df):
280+
from datafusion import col
281+
282+
result_str = df.sort("a").to_pydict()
283+
result_expr = df.sort(col("a")).to_pydict()
284+
assert result_str == result_expr
285+
286+
287+
def test_sort_unsupported(df):
288+
with pytest.raises(
289+
TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)"
290+
):
291+
df.sort(1)
292+
293+
294+
def test_aggregate_string_and_expression_equivalent(df):
295+
result_str = df.aggregate("a", [f.count()]).to_pydict()
296+
result_expr = df.aggregate(col("a"), [f.count()]).to_pydict()
297+
assert result_str == result_expr
298+
299+
300+
def test_filter_string_unsupported(df):
301+
with pytest.raises(TypeError, match=r"col\(\) or lit\(\)"):
302+
df.filter("a > 1")
303+
304+
271305
def test_drop(df):
272306
df = df.drop("c")
273307

@@ -337,6 +371,11 @@ def test_with_column(df):
337371
assert result.column(2) == pa.array([5, 7, 9])
338372

339373

374+
def test_with_column_invalid_expr(df):
375+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
376+
df.with_column("c", "a")
377+
378+
340379
def test_with_columns(df):
341380
df = df.with_columns(
342381
(column("a") + column("b")).alias("c"),
@@ -368,6 +407,13 @@ def test_with_columns(df):
368407
assert result.column(6) == pa.array([5, 7, 9])
369408

370409

410+
def test_with_columns_invalid_expr(df):
411+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
412+
df.with_columns("a")
413+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
414+
df.with_columns(c="a")
415+
416+
371417
def test_cast(df):
372418
df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
373419
expected = pa.schema(
@@ -526,6 +572,25 @@ def test_join_on():
526572
assert table.to_pydict() == expected
527573

528574

575+
def test_join_on_invalid_expr():
576+
ctx = SessionContext()
577+
578+
batch = pa.RecordBatch.from_arrays(
579+
[pa.array([1, 2]), pa.array([4, 5])],
580+
names=["a", "b"],
581+
)
582+
df = ctx.create_dataframe([[batch]], "l")
583+
df1 = ctx.create_dataframe([[batch]], "r")
584+
585+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
586+
df.join_on(df1, "a")
587+
588+
589+
def test_aggregate_invalid_aggs(df):
590+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
591+
df.aggregate([], "a")
592+
593+
529594
def test_distinct():
530595
ctx = SessionContext()
531596

0 commit comments

Comments
 (0)