4040from datafusion ._internal import DataFrame as DataFrameInternal
4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43- from datafusion .expr import Expr , SortExpr , sort_or_default
43+ from datafusion .expr import (
44+ _EXPR_TYPE_ERROR ,
45+ Expr ,
46+ SortExpr ,
47+ expr_list_to_raw_expr_list ,
48+ sort_or_default ,
49+ )
4450from datafusion .plan import ExecutionPlan , LogicalPlan
4551from datafusion .record_batch import RecordBatchStream
4652
@@ -394,9 +400,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394400 df = df.select("a", col("b"), col("a").alias("alternate_a"))
395401
396402 """
397- exprs_internal = [
398- Expr .column (arg ).expr if isinstance (arg , str ) else arg .expr for arg in exprs
399- ]
403+ exprs_internal = expr_list_to_raw_expr_list (exprs )
400404 return DataFrame (self .df .select (* exprs_internal ))
401405
402406 def drop (self , * columns : str ) -> DataFrame :
@@ -426,7 +430,9 @@ def filter(self, *predicates: Expr) -> DataFrame:
426430 """
427431 df = self .df
428432 for p in predicates :
429- df = df .filter (p .expr )
433+ if isinstance (p , str ) or not isinstance (p , Expr ):
434+ raise TypeError (_EXPR_TYPE_ERROR )
435+ df = df .filter (expr_list_to_raw_expr_list (p )[0 ])
430436 return DataFrame (df )
431437
432438 def with_column (self , name : str , expr : Expr ) -> DataFrame :
@@ -439,6 +445,8 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
439445 Returns:
440446 DataFrame with the new column.
441447 """
448+ if not isinstance (expr , Expr ):
449+ raise TypeError (_EXPR_TYPE_ERROR )
442450 return DataFrame (self .df .with_column (name , expr .expr ))
443451
444452 def with_columns (
@@ -470,14 +478,18 @@ def _simplify_expression(
470478 ) -> list [expr_internal .Expr ]:
471479 expr_list = []
472480 for expr in exprs :
473- if isinstance (expr , Expr ):
474- expr_list .append (expr .expr )
475- elif isinstance (expr , Iterable ):
476- expr_list .extend (inner_expr .expr for inner_expr in expr )
477- else :
478- raise NotImplementedError
481+ if isinstance (expr , str ):
482+ raise TypeError (_EXPR_TYPE_ERROR )
483+ if isinstance (expr , Iterable ) and not isinstance (expr , Expr ):
484+ if any (not isinstance (inner_expr , Expr ) for inner_expr in expr ):
485+ raise TypeError (_EXPR_TYPE_ERROR )
486+ elif not isinstance (expr , Expr ):
487+ raise TypeError (_EXPR_TYPE_ERROR )
488+ expr_list .extend (expr_list_to_raw_expr_list (expr ))
479489 if named_exprs :
480490 for alias , expr in named_exprs .items ():
491+ if not isinstance (expr , Expr ):
492+ raise TypeError (_EXPR_TYPE_ERROR )
481493 expr_list .append (expr .alias (alias ).expr )
482494 return expr_list
483495
@@ -503,37 +515,56 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
503515 return DataFrame (self .df .with_column_renamed (old_name , new_name ))
504516
def aggregate(
    self,
    group_by: list[Expr | str] | Expr | str,
    aggs: list[Expr] | Expr,
) -> DataFrame:
    """Aggregates the rows of the current DataFrame.

    Args:
        group_by: List of expressions or column names to group by. A single
            expression or column name is also accepted.
        aggs: List of expressions to aggregate. A single expression is also
            accepted.

    Returns:
        DataFrame after aggregation.

    Raises:
        TypeError: If any entry of ``aggs`` is not an ``Expr``.
    """
    # A lone value (anything that is not a ``list``) is treated as a
    # one-element list so both call styles share the same path below.
    group_by_list = group_by if isinstance(group_by, list) else [group_by]
    aggs_list = aggs if isinstance(aggs, list) else [aggs]

    # Conversion of the grouping keys is delegated to the shared helper.
    group_by_exprs = expr_list_to_raw_expr_list(group_by_list)

    # Aggregations are validated one by one so a wrong argument type fails
    # with the shared error message before ``.expr`` is accessed.
    aggs_exprs = []
    for agg in aggs_list:
        if not isinstance(agg, Expr):
            raise TypeError(_EXPR_TYPE_ERROR)
        aggs_exprs.append(agg.expr)
    return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
523541
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
    """Sort the DataFrame by the specified sorting expressions or column names.

    Note that any expression can be turned into a sort expression by
    calling its ``sort`` method.

    Args:
        exprs: Sort expressions or column names, applied in order.

    Returns:
        DataFrame after sorting.

    Raises:
        TypeError: If an argument is not an ``Expr``, a ``SortExpr`` or a
            ``str``.
    """
    exprs_raw = []
    for e in exprs:
        if not isinstance(e, (SortExpr, str, Expr)):
            error = (
                "Expected Expr or column name, found:"
                f" {type(e).__name__}. {_EXPR_TYPE_ERROR}."
            )
            raise TypeError(error)
        # Column names are wrapped with ``Expr.column``; ``Expr`` and
        # ``SortExpr`` values are handed to ``sort_or_default`` unchanged.
        sort_key = Expr.column(e) if isinstance(e, str) else e
        exprs_raw.append(sort_or_default(sort_key))
    return DataFrame(self.df.sort(*exprs_raw))
538569
539570 def cast (self , mapping : dict [str , pa .DataType [Any ]]) -> DataFrame :
@@ -757,7 +788,11 @@ def join_on(
757788 Returns:
758789 DataFrame after join.
759790 """
760- exprs = [expr .expr for expr in on_exprs ]
791+ exprs = []
792+ for expr in on_exprs :
793+ if not isinstance (expr , Expr ):
794+ raise TypeError (_EXPR_TYPE_ERROR )
795+ exprs .append (expr .expr )
761796 return DataFrame (self .df .join_on (right .df , exprs , how ))
762797
763798 def explain (self , verbose : bool = False , analyze : bool = False ) -> None :
0 commit comments