4040from datafusion ._internal import DataFrame as DataFrameInternal
4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43- from datafusion .expr import Expr , SortExpr , sort_or_default
43+ from datafusion .expr import Expr , SortExpr , _to_expr_list , sort_or_default
4444from datafusion .plan import ExecutionPlan , LogicalPlan
4545from datafusion .record_batch import RecordBatchStream
4646
@@ -394,9 +394,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394394 df = df.select("a", col("b"), col("a").alias("alternate_a"))
395395
396396 """
397- exprs_internal = [
398- Expr .column (arg ).expr if isinstance (arg , str ) else arg .expr for arg in exprs
399- ]
397+ exprs_internal = _to_expr_list (exprs )
400398 return DataFrame (self .df .select (* exprs_internal ))
401399
402400 def drop (self , * columns : str ) -> DataFrame :
@@ -426,6 +424,12 @@ def filter(self, *predicates: Expr) -> DataFrame:
426424 """
427425 df = self .df
428426 for p in predicates :
427+ if not isinstance (p , Expr ):
428+ msg = (
429+ f"Expected Expr, got { type (p ).__name__ } . "
430+ "Use col() or lit() to construct expressions."
431+ )
432+ raise TypeError (msg )
429433 df = df .filter (p .expr )
430434 return DataFrame (df )
431435
@@ -439,6 +443,12 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
439443 Returns:
440444 DataFrame with the new column.
441445 """
446+ if not isinstance (expr , Expr ):
447+ msg = (
448+ f"Expected Expr, got { type (expr ).__name__ } . "
449+ "Use col() or lit() to construct expressions."
450+ )
451+ raise TypeError (msg )
442452 return DataFrame (self .df .with_column (name , expr .expr ))
443453
444454 def with_columns (
@@ -473,11 +483,28 @@ def _simplify_expression(
473483 if isinstance (expr , Expr ):
474484 expr_list .append (expr .expr )
475485 elif isinstance (expr , Iterable ):
476- expr_list .extend (inner_expr .expr for inner_expr in expr )
486+ for inner_expr in expr :
487+ if not isinstance (inner_expr , Expr ):
488+ msg = (
489+ f"Expected Expr, got { type (inner_expr ).__name__ } . "
490+ "Use col() or lit() to construct expressions."
491+ )
492+ raise TypeError (msg )
493+ expr_list .append (inner_expr .expr )
477494 else :
478- raise NotImplementedError
495+ msg = (
496+ f"Expected Expr, got { type (expr ).__name__ } . "
497+ "Use col() or lit() to construct expressions."
498+ )
499+ raise TypeError (msg )
479500 if named_exprs :
480501 for alias , expr in named_exprs .items ():
502+ if not isinstance (expr , Expr ):
503+ msg = (
504+ f"Expected Expr, got { type (expr ).__name__ } . "
505+ "Use col() or lit() to construct expressions."
506+ )
507+ raise TypeError (msg )
481508 expr_list .append (expr .alias (alias ).expr )
482509 return expr_list
483510
@@ -503,37 +530,56 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
503530 return DataFrame (self .df .with_column_renamed (old_name , new_name ))
504531
505532 def aggregate (
506- self , group_by : list [Expr ] | Expr , aggs : list [Expr ] | Expr
533+ self ,
534+ group_by : list [Expr | str ] | Expr | str ,
535+ aggs : list [Expr ] | Expr ,
507536 ) -> DataFrame :
508537 """Aggregates the rows of the current DataFrame.
509538
510539 Args:
511- group_by: List of expressions to group by.
540+ group_by: List of expressions or column names to group by.
512541 aggs: List of expressions to aggregate.
513542
514543 Returns:
515544 DataFrame after aggregation.
516545 """
517- group_by = group_by if isinstance (group_by , list ) else [group_by ]
518- aggs = aggs if isinstance (aggs , list ) else [aggs ]
546+ group_by_list = group_by if isinstance (group_by , list ) else [group_by ]
547+ aggs_list = aggs if isinstance (aggs , list ) else [aggs ]
519548
520- group_by = [e .expr for e in group_by ]
521- aggs = [e .expr for e in aggs ]
522- return DataFrame (self .df .aggregate (group_by , aggs ))
549+ group_by_exprs = [
550+ Expr .column (e ).expr if isinstance (e , str ) else e .expr for e in group_by_list
551+ ]
552+ aggs_exprs = []
553+ for agg in aggs_list :
554+ if not isinstance (agg , Expr ):
555+ msg = (
556+ f"Expected Expr, got { type (agg ).__name__ } . "
557+ "Use col() or lit() to construct expressions."
558+ )
559+ raise TypeError (msg )
560+ aggs_exprs .append (agg .expr )
561+ return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
523562
524- def sort (self , * exprs : Expr | SortExpr ) -> DataFrame :
525- """Sort the DataFrame by the specified sorting expressions.
563+ def sort (self , * exprs : Expr | SortExpr | str ) -> DataFrame :
564+ """Sort the DataFrame by the specified sorting expressions or column names.
526565
527566 Note that any expression can be turned into a sort expression by
528- calling its` ``sort`` method.
567+ calling its ``sort`` method.
529568
530569 Args:
531- exprs: Sort expressions, applied in order.
570+ exprs: Sort expressions or column names, applied in order.
532571
533572 Returns:
534573 DataFrame after sorting.
535574 """
536- exprs_raw = [sort_or_default (expr ) for expr in exprs ]
575+ expr_seq = [e for e in exprs if not isinstance (e , SortExpr )]
576+ raw_exprs_iter = iter (_to_expr_list (expr_seq ))
577+ exprs_raw = []
578+ for e in exprs :
579+ if isinstance (e , SortExpr ):
580+ exprs_raw .append (sort_or_default (e ))
581+ else :
582+ exprs_raw .append (sort_or_default (Expr (next (raw_exprs_iter ))))
537583 return DataFrame (self .df .sort (* exprs_raw ))
538584
539585 def cast (self , mapping : dict [str , pa .DataType [Any ]]) -> DataFrame :
@@ -757,7 +803,15 @@ def join_on(
757803 Returns:
758804 DataFrame after join.
759805 """
760- exprs = [expr .expr for expr in on_exprs ]
806+ exprs = []
807+ for expr in on_exprs :
808+ if not isinstance (expr , Expr ):
809+ msg = (
810+ f"Expected Expr, got { type (expr ).__name__ } . "
811+ "Use col() or lit() to construct expressions."
812+ )
813+ raise TypeError (msg )
814+ exprs .append (expr .expr )
761815 return DataFrame (self .df .join_on (right .df , exprs , how ))
762816
763817 def explain (self , verbose : bool = False , analyze : bool = False ) -> None :
0 commit comments