Skip to content

Commit a5304f5

Browse files
committed
Revert "UNPICK"
This reverts commit 44b9d1a.
1 parent 44b9d1a commit a5304f5

File tree

6 files changed

+357
-87
lines changed

6 files changed

+357
-87
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,51 @@ DataFusion's DataFrame API offers a wide range of operations:
126126
# Drop columns
127127
df = df.drop("temporary_column")
128128
129+
String Columns and Expressions
130+
------------------------------
131+
132+
Some ``DataFrame`` methods accept plain strings when an argument refers to an
133+
existing column. These include:
134+
135+
* :py:meth:`~datafusion.DataFrame.select`
136+
* :py:meth:`~datafusion.DataFrame.sort`
137+
* :py:meth:`~datafusion.DataFrame.drop`
138+
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139+
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140+
141+
For such methods, you can pass column names directly:
142+
143+
.. code-block:: python
144+
145+
from datafusion import col, functions as f
146+
147+
df.sort('id')
148+
df.aggregate('id', [f.count(col('value'))])
149+
150+
The same operation can also be written with explicit column expressions, using either ``col()`` or ``column()``:
151+
152+
.. code-block:: python
153+
154+
from datafusion import col, column, functions as f
155+
156+
df.sort(col('id'))
157+
df.aggregate(column('id'), [f.count(col('value'))])
158+
159+
Note that ``column()`` is an alias of ``col()``, so either name may be used interchangeably; the example above demonstrates both.
160+
161+
Whenever an argument represents an expression—such as in
162+
:py:meth:`~datafusion.DataFrame.filter` or
163+
:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns
164+
and wrap constant values with ``lit()`` (also available as ``literal()``):
165+
166+
.. code-block:: python
167+
168+
from datafusion import col, lit
169+
df.filter(col('age') > lit(21))
170+
171+
Without ``lit()``, DataFusion would treat ``21`` as a column name rather than a
172+
constant value.
173+
129174
Terminal Operations
130175
-------------------
131176

python/datafusion/context.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from datafusion.catalog import Catalog, CatalogProvider, Table
3333
from datafusion.dataframe import DataFrame
34-
from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
34+
from datafusion.expr import SortKey, sort_list_to_raw_sort_list
3535
from datafusion.record_batch import RecordBatchStream
3636
from datafusion.user_defined import AggregateUDF, ScalarUDF, TableFunction, WindowUDF
3737

@@ -553,7 +553,7 @@ def register_listing_table(
553553
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
554554
file_extension: str = ".parquet",
555555
schema: pa.Schema | None = None,
556-
file_sort_order: list[list[Expr | SortExpr]] | None = None,
556+
file_sort_order: list[list[SortKey]] | None = None,
557557
) -> None:
558558
"""Register multiple files as a single table.
559559
@@ -567,23 +567,20 @@ def register_listing_table(
567567
table_partition_cols: Partition columns.
568568
file_extension: File extension of the provided table.
569569
schema: The data source schema.
570-
file_sort_order: Sort order for the file.
570+
file_sort_order: Sort order for the file. Each sort key can be
571+
specified as a column name (``str``), an expression
572+
(``Expr``), or a ``SortExpr``.
571573
"""
572574
if table_partition_cols is None:
573575
table_partition_cols = []
574576
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
575-
file_sort_order_raw = (
576-
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
577-
if file_sort_order is not None
578-
else None
579-
)
580577
self.ctx.register_listing_table(
581578
name,
582579
str(path),
583580
table_partition_cols,
584581
file_extension,
585582
schema,
586-
file_sort_order_raw,
583+
self._convert_file_sort_order(file_sort_order),
587584
)
588585

589586
def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
@@ -808,7 +805,7 @@ def register_parquet(
808805
file_extension: str = ".parquet",
809806
skip_metadata: bool = True,
810807
schema: pa.Schema | None = None,
811-
file_sort_order: list[list[SortExpr]] | None = None,
808+
file_sort_order: list[list[SortKey]] | None = None,
812809
) -> None:
813810
"""Register a Parquet file as a table.
814811
@@ -827,7 +824,9 @@ def register_parquet(
827824
that may be in the file schema. This can help avoid schema
828825
conflicts due to metadata.
829826
schema: The data source schema.
830-
file_sort_order: Sort order for the file.
827+
file_sort_order: Sort order for the file. Each sort key can be
828+
specified as a column name (``str``), an expression
829+
(``Expr``), or a ``SortExpr``.
831830
"""
832831
if table_partition_cols is None:
833832
table_partition_cols = []
@@ -840,9 +839,7 @@ def register_parquet(
840839
file_extension,
841840
skip_metadata,
842841
schema,
843-
[sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order]
844-
if file_sort_order is not None
845-
else None,
842+
self._convert_file_sort_order(file_sort_order),
846843
)
847844

848845
def register_csv(
@@ -1099,7 +1096,7 @@ def read_parquet(
10991096
file_extension: str = ".parquet",
11001097
skip_metadata: bool = True,
11011098
schema: pa.Schema | None = None,
1102-
file_sort_order: list[list[Expr | SortExpr]] | None = None,
1099+
file_sort_order: list[list[SortKey]] | None = None,
11031100
) -> DataFrame:
11041101
"""Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.
11051102
@@ -1116,19 +1113,17 @@ def read_parquet(
11161113
schema: An optional schema representing the parquet files. If None,
11171114
the parquet reader will try to infer it based on data in the
11181115
file.
1119-
file_sort_order: Sort order for the file.
1116+
file_sort_order: Sort order for the file. Each sort key can be
1117+
specified as a column name (``str``), an expression
1118+
(``Expr``), or a ``SortExpr``.
11201119
11211120
Returns:
11221121
DataFrame representation of the read Parquet files
11231122
"""
11241123
if table_partition_cols is None:
11251124
table_partition_cols = []
11261125
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
1127-
file_sort_order = (
1128-
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
1129-
if file_sort_order is not None
1130-
else None
1131-
)
1126+
file_sort_order = self._convert_file_sort_order(file_sort_order)
11321127
return DataFrame(
11331128
self.ctx.read_parquet(
11341129
str(path),
@@ -1179,6 +1174,24 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
11791174
"""Execute the ``plan`` and return the results."""
11801175
return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
11811176

1177+
@staticmethod
1178+
def _convert_file_sort_order(
1179+
file_sort_order: list[list[SortKey]] | None,
1180+
) -> list[list[Any]] | None:
1181+
"""Convert nested ``SortKey`` lists into raw sort representations.
1182+
1183+
Each ``SortKey`` can be a column name string, an ``Expr``, or a
1184+
``SortExpr`` and will be converted using
1185+
:func:`datafusion.expr.sort_list_to_raw_sort_list`.
1186+
"""
1187+
# Convert each ``SortKey`` in the provided sort order to the low-level
1188+
# representation expected by the Rust bindings.
1189+
return (
1190+
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
1191+
if file_sort_order is not None
1192+
else None
1193+
)
1194+
11821195
@staticmethod
11831196
def _convert_table_partition_cols(
11841197
table_partition_cols: list[tuple[str, str | pa.DataType]],

python/datafusion/dataframe.py

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import Expr, SortExpr, sort_or_default
43+
from datafusion.expr import (
44+
EXPR_TYPE_ERROR,
45+
Expr,
46+
SortKey,
47+
expr_list_to_raw_expr_list,
48+
sort_list_to_raw_sort_list,
49+
)
4450
from datafusion.plan import ExecutionPlan, LogicalPlan
4551
from datafusion.record_batch import RecordBatchStream
4652

@@ -286,6 +292,23 @@ def __init__(
286292
self.bloom_filter_ndv = bloom_filter_ndv
287293

288294

295+
def _ensure_expr(value: Expr) -> expr_internal.Expr:
296+
"""Return the internal expression or raise ``TypeError`` if invalid.
297+
298+
Args:
299+
value: Candidate expression.
300+
301+
Returns:
302+
The internal expression representation.
303+
304+
Raises:
305+
TypeError: If ``value`` is not an instance of :class:`Expr`.
306+
"""
307+
if not isinstance(value, Expr):
308+
raise TypeError(EXPR_TYPE_ERROR)
309+
return value.expr
310+
311+
289312
class DataFrame:
290313
"""Two dimensional table representation of data.
291314
@@ -394,9 +417,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394417
df = df.select("a", col("b"), col("a").alias("alternate_a"))
395418
396419
"""
397-
exprs_internal = [
398-
Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
399-
]
420+
exprs_internal = expr_list_to_raw_expr_list(exprs)
400421
return DataFrame(self.df.select(*exprs_internal))
401422

402423
def drop(self, *columns: str) -> DataFrame:
@@ -426,7 +447,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
426447
"""
427448
df = self.df
428449
for p in predicates:
429-
df = df.filter(p.expr)
450+
df = df.filter(_ensure_expr(p))
430451
return DataFrame(df)
431452

432453
def with_column(self, name: str, expr: Expr) -> DataFrame:
@@ -439,7 +460,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
439460
Returns:
440461
DataFrame with the new column.
441462
"""
442-
return DataFrame(self.df.with_column(name, expr.expr))
463+
return DataFrame(self.df.with_column(name, _ensure_expr(expr)))
443464

444465
def with_columns(
445466
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
@@ -468,17 +489,24 @@ def with_columns(
468489
def _simplify_expression(
469490
*exprs: Expr | Iterable[Expr], **named_exprs: Expr
470491
) -> list[expr_internal.Expr]:
471-
expr_list = []
492+
expr_list: list[expr_internal.Expr] = []
472493
for expr in exprs:
473-
if isinstance(expr, Expr):
474-
expr_list.append(expr.expr)
475-
elif isinstance(expr, Iterable):
476-
expr_list.extend(inner_expr.expr for inner_expr in expr)
494+
if isinstance(expr, str):
495+
raise TypeError(EXPR_TYPE_ERROR)
496+
if isinstance(expr, Iterable) and not isinstance(expr, Expr):
497+
expr_value = list(expr)
498+
if any(isinstance(inner, str) for inner in expr_value):
499+
raise TypeError(EXPR_TYPE_ERROR)
477500
else:
478-
raise NotImplementedError
479-
if named_exprs:
480-
for alias, expr in named_exprs.items():
481-
expr_list.append(expr.alias(alias).expr)
501+
expr_value = expr
502+
try:
503+
expr_list.extend(expr_list_to_raw_expr_list(expr_value))
504+
except TypeError as err:
505+
raise TypeError(EXPR_TYPE_ERROR) from err
506+
for alias, expr in named_exprs.items():
507+
if not isinstance(expr, Expr):
508+
raise TypeError(EXPR_TYPE_ERROR)
509+
expr_list.append(expr.alias(alias).expr)
482510
return expr_list
483511

484512
expressions = _simplify_expression(*exprs, **named_exprs)
@@ -503,37 +531,43 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
503531
return DataFrame(self.df.with_column_renamed(old_name, new_name))
504532

505533
def aggregate(
506-
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
534+
self,
535+
group_by: list[Expr | str] | Expr | str,
536+
aggs: list[Expr] | Expr,
507537
) -> DataFrame:
508538
"""Aggregates the rows of the current DataFrame.
509539
510540
Args:
511-
group_by: List of expressions to group by.
541+
group_by: List of expressions or column names to group by.
512542
aggs: List of expressions to aggregate.
513543
514544
Returns:
515545
DataFrame after aggregation.
516546
"""
517-
group_by = group_by if isinstance(group_by, list) else [group_by]
518-
aggs = aggs if isinstance(aggs, list) else [aggs]
547+
group_by_list = group_by if isinstance(group_by, list) else [group_by]
548+
aggs_list = aggs if isinstance(aggs, list) else [aggs]
519549

520-
group_by = [e.expr for e in group_by]
521-
aggs = [e.expr for e in aggs]
522-
return DataFrame(self.df.aggregate(group_by, aggs))
550+
group_by_exprs = expr_list_to_raw_expr_list(group_by_list)
551+
aggs_exprs = []
552+
for agg in aggs_list:
553+
if not isinstance(agg, Expr):
554+
raise TypeError(EXPR_TYPE_ERROR)
555+
aggs_exprs.append(agg.expr)
556+
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
523557

524-
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
525-
"""Sort the DataFrame by the specified sorting expressions.
558+
def sort(self, *exprs: SortKey) -> DataFrame:
559+
"""Sort the DataFrame by the specified sorting expressions or column names.
526560
527561
Note that any expression can be turned into a sort expression by
528-
calling its` ``sort`` method.
562+
calling its ``sort`` method.
529563
530564
Args:
531-
exprs: Sort expressions, applied in order.
565+
exprs: Sort expressions or column names, applied in order.
532566
533567
Returns:
534568
DataFrame after sorting.
535569
"""
536-
exprs_raw = [sort_or_default(expr) for expr in exprs]
570+
exprs_raw = sort_list_to_raw_sort_list(list(exprs))
537571
return DataFrame(self.df.sort(*exprs_raw))
538572

539573
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -757,7 +791,7 @@ def join_on(
757791
Returns:
758792
DataFrame after join.
759793
"""
760-
exprs = [expr.expr for expr in on_exprs]
794+
exprs = [_ensure_expr(expr) for expr in on_exprs]
761795
return DataFrame(self.df.join_on(right.df, exprs, how))
762796

763797
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

0 commit comments

Comments
 (0)