4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
4343from datafusion .expr import (
44- EXPR_TYPE_ERROR ,
44+ _EXPR_TYPE_ERROR ,
4545 Expr ,
4646 SortExpr ,
4747 expr_list_to_raw_expr_list ,
@@ -430,9 +430,9 @@ def filter(self, *predicates: Expr) -> DataFrame:
430430 """
431431 df = self .df
432432 for p in predicates :
433- if isinstance ( p , str ) or not isinstance (p , Expr ):
434- raise TypeError (EXPR_TYPE_ERROR )
435- df = df .filter (expr_list_to_raw_expr_list ( p )[ 0 ] )
433+ if not isinstance (p , Expr ):
434+ raise TypeError (_EXPR_TYPE_ERROR )
435+ df = df .filter (p . expr )
436436 return DataFrame (df )
437437
438438 def with_column (self , name : str , expr : Expr ) -> DataFrame :
@@ -446,7 +446,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
446446 DataFrame with the new column.
447447 """
448448 if not isinstance (expr , Expr ):
449- raise TypeError (EXPR_TYPE_ERROR )
449+ raise TypeError (_EXPR_TYPE_ERROR )
450450 return DataFrame (self .df .with_column (name , expr .expr ))
451451
452452 def with_columns (
@@ -478,19 +478,20 @@ def _simplify_expression(
478478 ) -> list [expr_internal .Expr ]:
479479 expr_list : list [expr_internal .Expr ] = []
480480 for expr in exprs :
481- if isinstance (expr , str ):
482- raise TypeError (EXPR_TYPE_ERROR )
483- if isinstance (expr , Iterable ) and not isinstance (expr , Expr ):
484- if any (not isinstance (inner_expr , Expr ) for inner_expr in expr ):
485- raise TypeError (EXPR_TYPE_ERROR )
486- elif not isinstance (expr , Expr ):
487- raise TypeError (EXPR_TYPE_ERROR )
488- expr_list .extend (expr_list_to_raw_expr_list (expr ))
489- if named_exprs :
490- for alias , expr in named_exprs .items ():
491- if not isinstance (expr , Expr ):
492- raise TypeError (EXPR_TYPE_ERROR )
493- expr_list .append (expr .alias (alias ).expr )
481+ if isinstance (expr , str ) or (
482+ isinstance (expr , Iterable )
483+ and not isinstance (expr , Expr )
484+ and any (isinstance (inner , str ) for inner in expr )
485+ ):
486+ raise TypeError (_EXPR_TYPE_ERROR )
487+ try :
488+ expr_list .extend (expr_list_to_raw_expr_list (expr ))
489+ except TypeError as err :
490+ raise TypeError (_EXPR_TYPE_ERROR ) from err
491+ for alias , expr in named_exprs .items ():
492+ if not isinstance (expr , Expr ):
493+ raise TypeError (_EXPR_TYPE_ERROR )
494+ expr_list .append (expr .alias (alias ).expr )
494495 return expr_list
495496
496497 expressions = _simplify_expression (* exprs , ** named_exprs )
@@ -535,7 +536,7 @@ def aggregate(
535536 aggs_exprs = []
536537 for agg in aggs_list :
537538 if not isinstance (agg , Expr ):
538- raise TypeError (EXPR_TYPE_ERROR )
539+ raise TypeError (_EXPR_TYPE_ERROR )
539540 aggs_exprs .append (agg .expr )
540541 return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
541542
@@ -551,20 +552,7 @@ def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
551552 Returns:
552553 DataFrame after sorting.
553554 """
554- exprs_raw = []
555- for e in exprs :
556- if isinstance (e , SortExpr ):
557- exprs_raw .append (sort_or_default (e ))
558- elif isinstance (e , str ):
559- exprs_raw .append (sort_or_default (Expr .column (e )))
560- elif isinstance (e , Expr ):
561- exprs_raw .append (sort_or_default (e ))
562- else :
563- error = (
564- "Expected Expr or column name, found:"
565- f" { type (e ).__name__ } . { EXPR_TYPE_ERROR } ."
566- )
567- raise TypeError (error )
555+ exprs_raw = sort_list_to_raw_sort_list (list (exprs ))
568556 return DataFrame (self .df .sort (* exprs_raw ))
569557
570558 def cast (self , mapping : dict [str , pa .DataType [Any ]]) -> DataFrame :
@@ -791,7 +779,7 @@ def join_on(
791779 exprs = []
792780 for expr in on_exprs :
793781 if not isinstance (expr , Expr ):
794- raise TypeError (EXPR_TYPE_ERROR )
782+ raise TypeError (_EXPR_TYPE_ERROR )
795783 exprs .append (expr .expr )
796784 return DataFrame (self .df .join_on (right .df , exprs , how ))
797785
0 commit comments