4040from datafusion ._internal import DataFrame as DataFrameInternal
4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43- from datafusion .expr import (
44- Expr ,
45- SortExpr ,
46- _ensure_expr ,
47- _to_expr_list ,
48- sort_or_default ,
49- )
43+ from datafusion .expr import Expr , SortExpr , expr_list_to_raw_expr_list , sort_or_default
5044from datafusion .plan import ExecutionPlan , LogicalPlan
5145from datafusion .record_batch import RecordBatchStream
5246
@@ -400,19 +394,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
400394 df = df.select("a", col("b"), col("a").alias("alternate_a"))
401395
402396 """
403- checked_exprs : list [Expr | str ] = []
404- for expr in exprs :
405- if isinstance (expr , SortExpr ):
406- checked_exprs .append (expr .expr ())
407- elif isinstance (expr , (Expr , str )):
408- checked_exprs .append (expr )
409- else :
410- msg = (
411- f"Expected Expr or column name, got { type (expr ).__name__ } . "
412- "Use col() or lit() to construct expressions."
413- )
414- raise TypeError (msg )
415- exprs_internal = _to_expr_list (checked_exprs )
397+ exprs_internal = expr_list_to_raw_expr_list (exprs )
416398 return DataFrame (self .df .select (* exprs_internal ))
417399
418400 def drop (self , * columns : str ) -> DataFrame :
@@ -442,7 +424,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
442424 """
443425 df = self .df
444426 for p in predicates :
445- df = df .filter (_ensure_expr (p ))
427+ df = df .filter (expr_list_to_raw_expr_list (p )[ 0 ] )
446428 return DataFrame (df )
447429
448430 def with_column (self , name : str , expr : Expr ) -> DataFrame :
@@ -455,7 +437,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
455437 Returns:
456438 DataFrame with the new column.
457439 """
458- return DataFrame (self .df .with_column (name , _ensure_expr (expr )))
440+ return DataFrame (self .df .with_column (name , expr_list_to_raw_expr_list (expr )[ 0 ] ))
459441
460442 def with_columns (
461443 self , * exprs : Expr | Iterable [Expr ], ** named_exprs : Expr
@@ -487,12 +469,12 @@ def _simplify_expression(
487469 expr_list = []
488470 for expr in exprs :
489471 if isinstance (expr , Iterable ) and not isinstance (expr , Expr ):
490- expr_list .extend (_ensure_expr (inner_expr ) for inner_expr in expr )
472+ expr_list .extend (expr_list_to_raw_expr_list (inner_expr )[ 0 ] for inner_expr in expr )
491473 else :
492- expr_list .append (_ensure_expr (expr ))
474+ expr_list .append (expr_list_to_raw_expr_list (expr )[ 0 ] )
493475 if named_exprs :
494476 for alias , expr in named_exprs .items ():
495- _ensure_expr (expr )
477+ expr_list_to_raw_expr_list (expr )[ 0 ]
496478 expr_list .append (expr .alias (alias ).expr )
497479 return expr_list
498480
@@ -537,7 +519,7 @@ def aggregate(
537519 group_by_exprs = [
538520 Expr .column (e ).expr if isinstance (e , str ) else e .expr for e in group_by_list
539521 ]
540- aggs_exprs = [_ensure_expr (agg ) for agg in aggs_list ]
522+ aggs_exprs = [expr_list_to_raw_expr_list (agg )[ 0 ] for agg in aggs_list ]
541523 return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
542524
543525 def sort (self , * exprs : Expr | SortExpr | str ) -> DataFrame :
@@ -552,20 +534,9 @@ def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
552534 Returns:
553535 DataFrame after sorting.
554536 """
555- expr_seq : list [Expr | str ] = []
556- for e in exprs :
557- if isinstance (e , SortExpr ):
558- continue
559- if isinstance (e , (Expr , str )):
560- expr_seq .append (e )
561- else :
562- msg = (
563- f"Expected Expr or column name, got { type (e ).__name__ } . "
564- "Use col() or lit() to construct expressions."
565- )
566- raise TypeError (msg )
567- raw_exprs_iter = iter (_to_expr_list (expr_seq ))
568- exprs_raw : list [Any ] = []
537+ expr_seq = [e for e in exprs if not isinstance (e , SortExpr )]
538+ raw_exprs_iter = iter (expr_list_to_raw_expr_list (expr_seq ))
539+ exprs_raw = []
569540 for e in exprs :
570541 if isinstance (e , SortExpr ):
571542 exprs_raw .append (sort_or_default (e ))
@@ -794,7 +765,7 @@ def join_on(
794765 Returns:
795766 DataFrame after join.
796767 """
797- exprs = [_ensure_expr (expr ) for expr in on_exprs ]
768+ exprs = [expr_list_to_raw_expr_list (expr )[ 0 ] for expr in on_exprs ]
798769 return DataFrame (self .df .join_on (right .df , exprs , how ))
799770
800771 def explain (self , verbose : bool = False , analyze : bool = False ) -> None :
0 commit comments