Skip to content

Commit ca23c74

Browse files
committed
UNPICK
1 parent 767df54 commit ca23c74

File tree

4 files changed

+26
-188
lines changed

4 files changed

+26
-188
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -126,49 +126,6 @@ DataFusion's DataFrame API offers a wide range of operations:
126126
# Drop columns
127127
df = df.drop("temporary_column")
128128
129-
String Columns and Expressions
130-
------------------------------
131-
132-
Some ``DataFrame`` methods accept plain strings when an argument refers to an
133-
existing column. These include:
134-
135-
* :py:meth:`~datafusion.DataFrame.select`
136-
* :py:meth:`~datafusion.DataFrame.sort`
137-
* :py:meth:`~datafusion.DataFrame.drop`
138-
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139-
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140-
141-
For such methods, you can pass column names directly:
142-
143-
.. code-block:: python
144-
145-
from datafusion import col, functions as f
146-
147-
df.sort('id')
148-
df.aggregate('id', [f.count(col('value'))])
149-
150-
The same operation can also be written with an explicit column expression:
151-
152-
.. code-block:: python
153-
154-
from datafusion import col, functions as f
155-
156-
df.sort(col('id'))
157-
df.aggregate(col('id'), [f.count(col('value'))])
158-
159-
Whenever an argument represents an expression—such as in
160-
:py:meth:`~datafusion.DataFrame.filter` or
161-
:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns
162-
and wrap constant values with ``lit()`` (also available as ``literal()``):
163-
164-
.. code-block:: python
165-
166-
from datafusion import col, lit
167-
df.filter(col('age') > lit(21))
168-
169-
Without ``lit()`` DataFusion would treat ``21`` as a column name rather than a
170-
constant value.
171-
172129
Terminal Operations
173130
-------------------
174131

python/datafusion/dataframe.py

Lines changed: 21 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,7 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import (
44-
_EXPR_TYPE_ERROR,
45-
Expr,
46-
SortExpr,
47-
expr_list_to_raw_expr_list,
48-
sort_or_default,
49-
)
43+
from datafusion.expr import Expr, SortExpr, sort_or_default
5044
from datafusion.plan import ExecutionPlan, LogicalPlan
5145
from datafusion.record_batch import RecordBatchStream
5246

@@ -400,7 +394,9 @@ def select(self, *exprs: Expr | str) -> DataFrame:
400394
df = df.select("a", col("b"), col("a").alias("alternate_a"))
401395
402396
"""
403-
exprs_internal = expr_list_to_raw_expr_list(exprs)
397+
exprs_internal = [
398+
Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
399+
]
404400
return DataFrame(self.df.select(*exprs_internal))
405401

406402
def drop(self, *columns: str) -> DataFrame:
@@ -430,9 +426,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
430426
"""
431427
df = self.df
432428
for p in predicates:
433-
if isinstance(p, str) or not isinstance(p, Expr):
434-
raise TypeError(_EXPR_TYPE_ERROR)
435-
df = df.filter(expr_list_to_raw_expr_list(p)[0])
429+
df = df.filter(p.expr)
436430
return DataFrame(df)
437431

438432
def with_column(self, name: str, expr: Expr) -> DataFrame:
@@ -445,8 +439,6 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
445439
Returns:
446440
DataFrame with the new column.
447441
"""
448-
if not isinstance(expr, Expr):
449-
raise TypeError(_EXPR_TYPE_ERROR)
450442
return DataFrame(self.df.with_column(name, expr.expr))
451443

452444
def with_columns(
@@ -480,17 +472,12 @@ def _simplify_expression(
480472
for expr in exprs:
481473
if isinstance(expr, Expr):
482474
expr_list.append(expr.expr)
483-
elif isinstance(expr, Iterable) and not isinstance(expr, (str, Expr)):
484-
for inner_expr in expr:
485-
if not isinstance(inner_expr, Expr):
486-
raise TypeError(_EXPR_TYPE_ERROR)
487-
expr_list.append(inner_expr.expr)
475+
elif isinstance(expr, Iterable):
476+
expr_list.extend(inner_expr.expr for inner_expr in expr)
488477
else:
489-
raise TypeError(_EXPR_TYPE_ERROR)
478+
raise NotImplementedError
490479
if named_exprs:
491480
for alias, expr in named_exprs.items():
492-
if not isinstance(expr, Expr):
493-
raise TypeError(_EXPR_TYPE_ERROR)
494481
expr_list.append(expr.alias(alias).expr)
495482
return expr_list
496483

@@ -516,56 +503,37 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
516503
return DataFrame(self.df.with_column_renamed(old_name, new_name))
517504

518505
def aggregate(
519-
self,
520-
group_by: list[Expr | str] | Expr | str,
521-
aggs: list[Expr] | Expr,
506+
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
522507
) -> DataFrame:
523508
"""Aggregates the rows of the current DataFrame.
524509
525510
Args:
526-
group_by: List of expressions or column names to group by.
511+
group_by: List of expressions to group by.
527512
aggs: List of expressions to aggregate.
528513
529514
Returns:
530515
DataFrame after aggregation.
531516
"""
532-
group_by_list = group_by if isinstance(group_by, list) else [group_by]
533-
aggs_list = aggs if isinstance(aggs, list) else [aggs]
517+
group_by = group_by if isinstance(group_by, list) else [group_by]
518+
aggs = aggs if isinstance(aggs, list) else [aggs]
534519

535-
group_by_exprs = expr_list_to_raw_expr_list(group_by_list)
536-
aggs_exprs = []
537-
for agg in aggs_list:
538-
if not isinstance(agg, Expr):
539-
raise TypeError(_EXPR_TYPE_ERROR)
540-
aggs_exprs.append(agg.expr)
541-
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
520+
group_by = [e.expr for e in group_by]
521+
aggs = [e.expr for e in aggs]
522+
return DataFrame(self.df.aggregate(group_by, aggs))
542523

543-
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
544-
"""Sort the DataFrame by the specified sorting expressions or column names.
524+
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
525+
"""Sort the DataFrame by the specified sorting expressions.
545526
546527
Note that any expression can be turned into a sort expression by
547-
calling its ``sort`` method.
528+
calling its` ``sort`` method.
548529
549530
Args:
550-
exprs: Sort expressions or column names, applied in order.
531+
exprs: Sort expressions, applied in order.
551532
552533
Returns:
553534
DataFrame after sorting.
554535
"""
555-
exprs_raw = []
556-
for e in exprs:
557-
if isinstance(e, SortExpr):
558-
exprs_raw.append(sort_or_default(e))
559-
elif isinstance(e, str):
560-
exprs_raw.append(sort_or_default(Expr.column(e)))
561-
elif isinstance(e, Expr):
562-
exprs_raw.append(sort_or_default(e))
563-
else:
564-
error = (
565-
"Expected Expr or column name, found:"
566-
f" {type(e).__name__}. Use col() or lit() to construct expressions."
567-
)
568-
raise TypeError(error)
536+
exprs_raw = [sort_or_default(expr) for expr in exprs]
569537
return DataFrame(self.df.sort(*exprs_raw))
570538

571539
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -789,11 +757,7 @@ def join_on(
789757
Returns:
790758
DataFrame after join.
791759
"""
792-
exprs = []
793-
for expr in on_exprs:
794-
if not isinstance(expr, Expr):
795-
raise TypeError(_EXPR_TYPE_ERROR)
796-
exprs.append(expr.expr)
760+
exprs = [expr.expr for expr in on_exprs]
797761
return DataFrame(self.df.join_on(right.df, exprs, how))
798762

799763
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional
2626

2727
import pyarrow as pa
2828

@@ -39,10 +39,6 @@
3939
if TYPE_CHECKING:
4040
from datafusion.plan import LogicalPlan
4141

42-
43-
# Standard error message for invalid expression types
44-
_EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions"
45-
4642
# The following are imported from the internal representation. We may choose to
4743
# give these all proper wrappers, or to simply leave as is. These were added
4844
# in order to support passing the `test_imports` unit test.
@@ -220,26 +216,12 @@
220216

221217

222218
def expr_list_to_raw_expr_list(
223-
expr_list: Optional[Sequence[Expr | str] | Expr | str],
219+
expr_list: Optional[list[Expr] | Expr],
224220
) -> Optional[list[expr_internal.Expr]]:
225-
"""Convert a sequence of expressions or column names to raw expressions."""
226-
if isinstance(expr_list, (Expr, str)):
221+
"""Helper function to convert an optional list to raw expressions."""
222+
if isinstance(expr_list, Expr):
227223
expr_list = [expr_list]
228-
if expr_list is None:
229-
return None
230-
raw_exprs: list[expr_internal.Expr] = []
231-
for e in expr_list:
232-
if isinstance(e, str):
233-
raw_exprs.append(Expr.column(e).expr)
234-
elif isinstance(e, Expr):
235-
raw_exprs.append(e.expr)
236-
else:
237-
error = (
238-
"Expected Expr or column name, found:"
239-
f" {type(e).__name__}. {_EXPR_TYPE_ERROR}."
240-
)
241-
raise TypeError(error)
242-
return raw_exprs
224+
return [e.expr for e in expr_list] if expr_list is not None else None
243225

244226

245227
def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:

python/tests/test_dataframe.py

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
WindowFrame,
3434
column,
3535
literal,
36-
col,
3736
)
3837
from datafusion import (
3938
functions as f,
@@ -228,13 +227,6 @@ def test_select_mixed_expr_string(df):
228227
assert result.column(1) == pa.array([1, 2, 3])
229228

230229

231-
def test_select_unsupported(df):
232-
with pytest.raises(
233-
TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)"
234-
):
235-
df.select(1)
236-
237-
238230
def test_filter(df):
239231
df1 = df.filter(column("a") > literal(2)).select(
240232
column("a") + column("b"),
@@ -276,32 +268,6 @@ def test_sort(df):
276268
assert table.to_pydict() == expected
277269

278270

279-
def test_sort_string_and_expression_equivalent(df):
280-
from datafusion import col
281-
282-
result_str = df.sort("a").to_pydict()
283-
result_expr = df.sort(col("a")).to_pydict()
284-
assert result_str == result_expr
285-
286-
287-
def test_sort_unsupported(df):
288-
with pytest.raises(
289-
TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)"
290-
):
291-
df.sort(1)
292-
293-
294-
def test_aggregate_string_and_expression_equivalent(df):
295-
result_str = df.aggregate("a", [f.count()]).to_pydict()
296-
result_expr = df.aggregate(col("a"), [f.count()]).to_pydict()
297-
assert result_str == result_expr
298-
299-
300-
def test_filter_string_unsupported(df):
301-
with pytest.raises(TypeError, match=r"col\(\) or lit\(\)"):
302-
df.filter("a > 1")
303-
304-
305271
def test_drop(df):
306272
df = df.drop("c")
307273

@@ -371,11 +337,6 @@ def test_with_column(df):
371337
assert result.column(2) == pa.array([5, 7, 9])
372338

373339

374-
def test_with_column_invalid_expr(df):
375-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
376-
df.with_column("c", "a")
377-
378-
379340
def test_with_columns(df):
380341
df = df.with_columns(
381342
(column("a") + column("b")).alias("c"),
@@ -407,13 +368,6 @@ def test_with_columns(df):
407368
assert result.column(6) == pa.array([5, 7, 9])
408369

409370

410-
def test_with_columns_invalid_expr(df):
411-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
412-
df.with_columns("a")
413-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
414-
df.with_columns(c="a")
415-
416-
417371
def test_cast(df):
418372
df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
419373
expected = pa.schema(
@@ -572,25 +526,6 @@ def test_join_on():
572526
assert table.to_pydict() == expected
573527

574528

575-
def test_join_on_invalid_expr():
576-
ctx = SessionContext()
577-
578-
batch = pa.RecordBatch.from_arrays(
579-
[pa.array([1, 2]), pa.array([4, 5])],
580-
names=["a", "b"],
581-
)
582-
df = ctx.create_dataframe([[batch]], "l")
583-
df1 = ctx.create_dataframe([[batch]], "r")
584-
585-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
586-
df.join_on(df1, "a")
587-
588-
589-
def test_aggregate_invalid_aggs(df):
590-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
591-
df.aggregate([], "a")
592-
593-
594529
def test_distinct():
595530
ctx = SessionContext()
596531

0 commit comments

Comments
 (0)