4040from datafusion ._internal import DataFrame as DataFrameInternal
4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43- from datafusion .expr import Expr , SortExpr , sort_or_default
43+ from datafusion .expr import (
44+ EXPR_TYPE_ERROR ,
45+ Expr ,
46+ SortKey ,
47+ expr_list_to_raw_expr_list ,
48+ sort_list_to_raw_sort_list ,
49+ )
4450from datafusion .plan import ExecutionPlan , LogicalPlan
4551from datafusion .record_batch import RecordBatchStream
4652
@@ -286,6 +292,23 @@ def __init__(
286292 self .bloom_filter_ndv = bloom_filter_ndv
287293
288294
295+ def _ensure_expr (value : Expr ) -> expr_internal .Expr :
296+ """Return the internal expression or raise ``TypeError`` if invalid.
297+
298+ Args:
299+ value: Candidate expression.
300+
301+ Returns:
302+ The internal expression representation.
303+
304+ Raises:
305+ TypeError: If ``value`` is not an instance of :class:`Expr`.
306+ """
307+ if not isinstance (value , Expr ):
308+ raise TypeError (EXPR_TYPE_ERROR )
309+ return value .expr
310+
311+
289312class DataFrame :
290313 """Two dimensional table representation of data.
291314
@@ -394,9 +417,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394417 df = df.select("a", col("b"), col("a").alias("alternate_a"))
395418
396419 """
397- exprs_internal = [
398- Expr .column (arg ).expr if isinstance (arg , str ) else arg .expr for arg in exprs
399- ]
420+ exprs_internal = expr_list_to_raw_expr_list (exprs )
400421 return DataFrame (self .df .select (* exprs_internal ))
401422
402423 def drop (self , * columns : str ) -> DataFrame :
@@ -426,7 +447,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
426447 """
427448 df = self .df
428449 for p in predicates :
429- df = df .filter (p . expr )
450+ df = df .filter (_ensure_expr ( p ) )
430451 return DataFrame (df )
431452
432453 def with_column (self , name : str , expr : Expr ) -> DataFrame :
@@ -439,7 +460,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
439460 Returns:
440461 DataFrame with the new column.
441462 """
442- return DataFrame (self .df .with_column (name , expr . expr ))
463+ return DataFrame (self .df .with_column (name , _ensure_expr ( expr ) ))
443464
444465 def with_columns (
445466 self , * exprs : Expr | Iterable [Expr ], ** named_exprs : Expr
@@ -468,17 +489,24 @@ def with_columns(
468489 def _simplify_expression (
469490 * exprs : Expr | Iterable [Expr ], ** named_exprs : Expr
470491 ) -> list [expr_internal .Expr ]:
471- expr_list = []
492+ expr_list : list [ expr_internal . Expr ] = []
472493 for expr in exprs :
473- if isinstance (expr , Expr ):
474- expr_list .append (expr .expr )
475- elif isinstance (expr , Iterable ):
476- expr_list .extend (inner_expr .expr for inner_expr in expr )
494+ if isinstance (expr , str ):
495+ raise TypeError (EXPR_TYPE_ERROR )
496+ if isinstance (expr , Iterable ) and not isinstance (expr , Expr ):
497+ expr_value = list (expr )
498+ if any (isinstance (inner , str ) for inner in expr_value ):
499+ raise TypeError (EXPR_TYPE_ERROR )
477500 else :
478- raise NotImplementedError
479- if named_exprs :
480- for alias , expr in named_exprs .items ():
481- expr_list .append (expr .alias (alias ).expr )
501+ expr_value = expr
502+ try :
503+ expr_list .extend (expr_list_to_raw_expr_list (expr_value ))
504+ except TypeError as err :
505+ raise TypeError (EXPR_TYPE_ERROR ) from err
506+ for alias , expr in named_exprs .items ():
507+ if not isinstance (expr , Expr ):
508+ raise TypeError (EXPR_TYPE_ERROR )
509+ expr_list .append (expr .alias (alias ).expr )
482510 return expr_list
483511
484512 expressions = _simplify_expression (* exprs , ** named_exprs )
@@ -503,37 +531,43 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
503531 return DataFrame (self .df .with_column_renamed (old_name , new_name ))
504532
505533 def aggregate (
506- self , group_by : list [Expr ] | Expr , aggs : list [Expr ] | Expr
534+ self ,
535+ group_by : list [Expr | str ] | Expr | str ,
536+ aggs : list [Expr ] | Expr ,
507537 ) -> DataFrame :
508538 """Aggregates the rows of the current DataFrame.
509539
510540 Args:
511- group_by: List of expressions to group by.
541+ group_by: List of expressions or column names to group by.
512542 aggs: List of expressions to aggregate.
513543
514544 Returns:
515545 DataFrame after aggregation.
516546 """
517- group_by = group_by if isinstance (group_by , list ) else [group_by ]
518- aggs = aggs if isinstance (aggs , list ) else [aggs ]
547+ group_by_list = group_by if isinstance (group_by , list ) else [group_by ]
548+ aggs_list = aggs if isinstance (aggs , list ) else [aggs ]
519549
520- group_by = [e .expr for e in group_by ]
521- aggs = [e .expr for e in aggs ]
522- return DataFrame (self .df .aggregate (group_by , aggs ))
550+ group_by_exprs = expr_list_to_raw_expr_list (group_by_list )
551+ aggs_exprs = []
552+ for agg in aggs_list :
553+ if not isinstance (agg , Expr ):
554+ raise TypeError (EXPR_TYPE_ERROR )
555+ aggs_exprs .append (agg .expr )
556+ return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
523557
524- def sort (self , * exprs : Expr | SortExpr ) -> DataFrame :
525- """Sort the DataFrame by the specified sorting expressions.
558+ def sort (self , * exprs : SortKey ) -> DataFrame :
559+ """Sort the DataFrame by the specified sorting expressions or column names .
526560
527561 Note that any expression can be turned into a sort expression by
528- calling its` ``sort`` method.
562+ calling its ``sort`` method.
529563
530564 Args:
531- exprs: Sort expressions, applied in order.
565+ exprs: Sort expressions or column names , applied in order.
532566
533567 Returns:
534568 DataFrame after sorting.
535569 """
536- exprs_raw = [ sort_or_default ( expr ) for expr in exprs ]
570+ exprs_raw = sort_list_to_raw_sort_list ( list ( exprs ))
537571 return DataFrame (self .df .sort (* exprs_raw ))
538572
539573 def cast (self , mapping : dict [str , pa .DataType [Any ]]) -> DataFrame :
@@ -757,7 +791,7 @@ def join_on(
757791 Returns:
758792 DataFrame after join.
759793 """
760- exprs = [expr . expr for expr in on_exprs ]
794+ exprs = [_ensure_expr ( expr ) for expr in on_exprs ]
761795 return DataFrame (self .df .join_on (right .df , exprs , how ))
762796
763797 def explain (self , verbose : bool = False , analyze : bool = False ) -> None :
0 commit comments