Skip to content

Commit 4415cc9

Browse files
committed
UNPICK
1 parent dd9771c commit 4415cc9

File tree

6 files changed

+52
-218
lines changed

6 files changed

+52
-218
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -126,51 +126,6 @@ DataFusion's DataFrame API offers a wide range of operations:
126126
# Drop columns
127127
df = df.drop("temporary_column")
128128
129-
String Columns and Expressions
130-
------------------------------
131-
132-
Some ``DataFrame`` methods accept plain strings when an argument refers to an
133-
existing column. These include:
134-
135-
* :py:meth:`~datafusion.DataFrame.select`
136-
* :py:meth:`~datafusion.DataFrame.sort`
137-
* :py:meth:`~datafusion.DataFrame.drop`
138-
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139-
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140-
141-
For such methods, you can pass column names directly:
142-
143-
.. code-block:: python
144-
145-
from datafusion import col, functions as f
146-
147-
df.sort('id')
148-
df.aggregate('id', [f.count(col('value'))])
149-
150-
The same operation can also be written with explicit column expressions, using either ``col()`` or ``column()``:
151-
152-
.. code-block:: python
153-
154-
from datafusion import col, column, functions as f
155-
156-
df.sort(col('id'))
157-
df.aggregate(column('id'), [f.count(col('value'))])
158-
159-
Note that ``column()`` is an alias of ``col()``, so you can use either name; the example above shows both in action.
160-
161-
Whenever an argument represents an expression—such as in
162-
:py:meth:`~datafusion.DataFrame.filter` or
163-
:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns
164-
and wrap constant values with ``lit()`` (also available as ``literal()``):
165-
166-
.. code-block:: python
167-
168-
from datafusion import col, lit
169-
df.filter(col('age') > lit(21))
170-
171-
Without ``lit()`` DataFusion would treat ``21`` as a column name rather than a
172-
constant value.
173-
174129
Terminal Operations
175130
-------------------
176131

python/datafusion/context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def register_listing_table(
553553
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
554554
file_extension: str = ".parquet",
555555
schema: pa.Schema | None = None,
556-
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
556+
file_sort_order: list[list[Expr | SortExpr]] | None = None,
557557
) -> None:
558558
"""Register multiple files as a single table.
559559
@@ -808,7 +808,7 @@ def register_parquet(
808808
file_extension: str = ".parquet",
809809
skip_metadata: bool = True,
810810
schema: pa.Schema | None = None,
811-
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
811+
file_sort_order: list[list[SortExpr]] | None = None,
812812
) -> None:
813813
"""Register a Parquet file as a table.
814814
@@ -1099,7 +1099,7 @@ def read_parquet(
10991099
file_extension: str = ".parquet",
11001100
skip_metadata: bool = True,
11011101
schema: pa.Schema | None = None,
1102-
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
1102+
file_sort_order: list[list[Expr | SortExpr]] | None = None,
11031103
) -> DataFrame:
11041104
"""Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.
11051105

python/datafusion/dataframe.py

Lines changed: 27 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,7 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import (
44-
EXPR_TYPE_ERROR,
45-
Expr,
46-
SortExpr,
47-
expr_list_to_raw_expr_list,
48-
sort_list_to_raw_sort_list,
49-
)
43+
from datafusion.expr import Expr, SortExpr, sort_or_default
5044
from datafusion.plan import ExecutionPlan, LogicalPlan
5145
from datafusion.record_batch import RecordBatchStream
5246

@@ -400,7 +394,9 @@ def select(self, *exprs: Expr | str) -> DataFrame:
400394
df = df.select("a", col("b"), col("a").alias("alternate_a"))
401395
402396
"""
403-
exprs_internal = expr_list_to_raw_expr_list(exprs)
397+
exprs_internal = [
398+
Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
399+
]
404400
return DataFrame(self.df.select(*exprs_internal))
405401

406402
def drop(self, *columns: str) -> DataFrame:
@@ -430,8 +426,6 @@ def filter(self, *predicates: Expr) -> DataFrame:
430426
"""
431427
df = self.df
432428
for p in predicates:
433-
if not isinstance(p, Expr):
434-
raise TypeError(EXPR_TYPE_ERROR)
435429
df = df.filter(p.expr)
436430
return DataFrame(df)
437431

@@ -445,8 +439,6 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
445439
Returns:
446440
DataFrame with the new column.
447441
"""
448-
if not isinstance(expr, Expr):
449-
raise TypeError(EXPR_TYPE_ERROR)
450442
return DataFrame(self.df.with_column(name, expr.expr))
451443

452444
def with_columns(
@@ -476,22 +468,17 @@ def with_columns(
476468
def _simplify_expression(
477469
*exprs: Expr | Iterable[Expr], **named_exprs: Expr
478470
) -> list[expr_internal.Expr]:
479-
expr_list: list[expr_internal.Expr] = []
471+
expr_list = []
480472
for expr in exprs:
481-
if isinstance(expr, str) or (
482-
isinstance(expr, Iterable)
483-
and not isinstance(expr, Expr)
484-
and any(isinstance(inner, str) for inner in expr)
485-
):
486-
raise TypeError(EXPR_TYPE_ERROR)
487-
try:
488-
expr_list.extend(expr_list_to_raw_expr_list(expr))
489-
except TypeError as err:
490-
raise TypeError(EXPR_TYPE_ERROR) from err
491-
for alias, expr in named_exprs.items():
492-
if not isinstance(expr, Expr):
493-
raise TypeError(EXPR_TYPE_ERROR)
494-
expr_list.append(expr.alias(alias).expr)
473+
if isinstance(expr, Expr):
474+
expr_list.append(expr.expr)
475+
elif isinstance(expr, Iterable):
476+
expr_list.extend(inner_expr.expr for inner_expr in expr)
477+
else:
478+
raise NotImplementedError
479+
if named_exprs:
480+
for alias, expr in named_exprs.items():
481+
expr_list.append(expr.alias(alias).expr)
495482
return expr_list
496483

497484
expressions = _simplify_expression(*exprs, **named_exprs)
@@ -516,43 +503,37 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
516503
return DataFrame(self.df.with_column_renamed(old_name, new_name))
517504

518505
def aggregate(
519-
self,
520-
group_by: list[Expr | str] | Expr | str,
521-
aggs: list[Expr] | Expr,
506+
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
522507
) -> DataFrame:
523508
"""Aggregates the rows of the current DataFrame.
524509
525510
Args:
526-
group_by: List of expressions or column names to group by.
511+
group_by: List of expressions to group by.
527512
aggs: List of expressions to aggregate.
528513
529514
Returns:
530515
DataFrame after aggregation.
531516
"""
532-
group_by_list = group_by if isinstance(group_by, list) else [group_by]
533-
aggs_list = aggs if isinstance(aggs, list) else [aggs]
517+
group_by = group_by if isinstance(group_by, list) else [group_by]
518+
aggs = aggs if isinstance(aggs, list) else [aggs]
534519

535-
group_by_exprs = expr_list_to_raw_expr_list(group_by_list)
536-
aggs_exprs = []
537-
for agg in aggs_list:
538-
if not isinstance(agg, Expr):
539-
raise TypeError(EXPR_TYPE_ERROR)
540-
aggs_exprs.append(agg.expr)
541-
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
520+
group_by = [e.expr for e in group_by]
521+
aggs = [e.expr for e in aggs]
522+
return DataFrame(self.df.aggregate(group_by, aggs))
542523

543-
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
544-
"""Sort the DataFrame by the specified sorting expressions or column names.
524+
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
525+
"""Sort the DataFrame by the specified sorting expressions.
545526
546527
Note that any expression can be turned into a sort expression by
547-
calling its ``sort`` method.
528+
calling its` ``sort`` method.
548529
549530
Args:
550-
exprs: Sort expressions or column names, applied in order.
531+
exprs: Sort expressions, applied in order.
551532
552533
Returns:
553534
DataFrame after sorting.
554535
"""
555-
exprs_raw = sort_list_to_raw_sort_list(list(exprs))
536+
exprs_raw = [sort_or_default(expr) for expr in exprs]
556537
return DataFrame(self.df.sort(*exprs_raw))
557538

558539
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -776,11 +757,7 @@ def join_on(
776757
Returns:
777758
DataFrame after join.
778759
"""
779-
exprs = []
780-
for expr in on_exprs:
781-
if not isinstance(expr, Expr):
782-
raise TypeError(EXPR_TYPE_ERROR)
783-
exprs.append(expr.expr)
760+
exprs = [expr.expr for expr in on_exprs]
784761
return DataFrame(self.df.join_on(right.df, exprs, how))
785762

786763
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 8 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional
2626

2727
import pyarrow as pa
2828

@@ -39,10 +39,6 @@
3939
if TYPE_CHECKING:
4040
from datafusion.plan import LogicalPlan
4141

42-
43-
# Standard error message for invalid expression types
44-
EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions"
45-
4642
# The following are imported from the internal representation. We may choose to
4743
# give these all proper wrappers, or to simply leave as is. These were added
4844
# in order to support passing the `test_imports` unit test.
@@ -220,26 +216,12 @@
220216

221217

222218
def expr_list_to_raw_expr_list(
223-
expr_list: Optional[Sequence[Expr | str] | Expr | str],
219+
expr_list: Optional[list[Expr] | Expr],
224220
) -> Optional[list[expr_internal.Expr]]:
225-
"""Convert a sequence of expressions or column names to raw expressions."""
226-
if isinstance(expr_list, (Expr, str)):
221+
"""Helper function to convert an optional list to raw expressions."""
222+
if isinstance(expr_list, Expr):
227223
expr_list = [expr_list]
228-
if expr_list is None:
229-
return None
230-
raw_exprs: list[expr_internal.Expr] = []
231-
for e in expr_list:
232-
if isinstance(e, str):
233-
raw_exprs.append(Expr.column(e).expr)
234-
elif isinstance(e, Expr):
235-
raw_exprs.append(e.expr)
236-
else:
237-
error = (
238-
"Expected Expr or column name, found:"
239-
f" {type(e).__name__}. {EXPR_TYPE_ERROR}."
240-
)
241-
raise TypeError(error)
242-
return raw_exprs
224+
return [e.expr for e in expr_list] if expr_list is not None else None
243225

244226

245227
def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
@@ -250,27 +232,12 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
250232

251233

252234
def sort_list_to_raw_sort_list(
253-
sort_list: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str],
235+
sort_list: Optional[list[Expr | SortExpr] | Expr | SortExpr],
254236
) -> Optional[list[expr_internal.SortExpr]]:
255237
"""Helper function to return an optional sort list to raw variant."""
256-
if isinstance(sort_list, (Expr, SortExpr, str)):
238+
if isinstance(sort_list, (Expr, SortExpr)):
257239
sort_list = [sort_list]
258-
if sort_list is None:
259-
return None
260-
raw_sort_list = []
261-
for item in sort_list:
262-
if isinstance(item, str):
263-
expr_obj = Expr.column(item)
264-
elif isinstance(item, (Expr, SortExpr)):
265-
expr_obj = item
266-
else:
267-
error = (
268-
"Expected Expr or column name, found:"
269-
f" {type(item).__name__}. {EXPR_TYPE_ERROR}."
270-
)
271-
raise TypeError(error)
272-
raw_sort_list.append(sort_or_default(expr_obj))
273-
return raw_sort_list
240+
return [sort_or_default(e) for e in sort_list] if sort_list is not None else None
274241

275242

276243
class Expr:

0 commit comments

Comments
 (0)