Skip to content

Commit 80f4a6f

Browse files
committed
Refactor expression handling by replacing _ensure_expr and _to_expr_list with expr_list_to_raw_expr_list for improved clarity and consistency in DataFrame methods.
1 parent 60853ed commit 80f4a6f

File tree

2 files changed

+18
-61
lines changed

2 files changed

+18
-61
lines changed

python/datafusion/dataframe.py

Lines changed: 12 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -40,13 +40,7 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import (
44-
Expr,
45-
SortExpr,
46-
_ensure_expr,
47-
_to_expr_list,
48-
sort_or_default,
49-
)
43+
from datafusion.expr import Expr, SortExpr, expr_list_to_raw_expr_list, sort_or_default
5044
from datafusion.plan import ExecutionPlan, LogicalPlan
5145
from datafusion.record_batch import RecordBatchStream
5246

@@ -400,19 +394,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
400394
df = df.select("a", col("b"), col("a").alias("alternate_a"))
401395
402396
"""
403-
checked_exprs: list[Expr | str] = []
404-
for expr in exprs:
405-
if isinstance(expr, SortExpr):
406-
checked_exprs.append(expr.expr())
407-
elif isinstance(expr, (Expr, str)):
408-
checked_exprs.append(expr)
409-
else:
410-
msg = (
411-
f"Expected Expr or column name, got {type(expr).__name__}. "
412-
"Use col() or lit() to construct expressions."
413-
)
414-
raise TypeError(msg)
415-
exprs_internal = _to_expr_list(checked_exprs)
397+
exprs_internal = expr_list_to_raw_expr_list(exprs)
416398
return DataFrame(self.df.select(*exprs_internal))
417399

418400
def drop(self, *columns: str) -> DataFrame:
@@ -442,7 +424,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
442424
"""
443425
df = self.df
444426
for p in predicates:
445-
df = df.filter(_ensure_expr(p))
427+
df = df.filter(expr_list_to_raw_expr_list(p)[0])
446428
return DataFrame(df)
447429

448430
def with_column(self, name: str, expr: Expr) -> DataFrame:
@@ -455,7 +437,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
455437
Returns:
456438
DataFrame with the new column.
457439
"""
458-
return DataFrame(self.df.with_column(name, _ensure_expr(expr)))
440+
return DataFrame(self.df.with_column(name, expr_list_to_raw_expr_list(expr)[0]))
459441

460442
def with_columns(
461443
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
@@ -487,12 +469,12 @@ def _simplify_expression(
487469
expr_list = []
488470
for expr in exprs:
489471
if isinstance(expr, Iterable) and not isinstance(expr, Expr):
490-
expr_list.extend(_ensure_expr(inner_expr) for inner_expr in expr)
472+
expr_list.extend(expr_list_to_raw_expr_list(inner_expr)[0] for inner_expr in expr)
491473
else:
492-
expr_list.append(_ensure_expr(expr))
474+
expr_list.append(expr_list_to_raw_expr_list(expr)[0])
493475
if named_exprs:
494476
for alias, expr in named_exprs.items():
495-
_ensure_expr(expr)
477+
expr_list_to_raw_expr_list(expr)[0]
496478
expr_list.append(expr.alias(alias).expr)
497479
return expr_list
498480

@@ -537,7 +519,7 @@ def aggregate(
537519
group_by_exprs = [
538520
Expr.column(e).expr if isinstance(e, str) else e.expr for e in group_by_list
539521
]
540-
aggs_exprs = [_ensure_expr(agg) for agg in aggs_list]
522+
aggs_exprs = [expr_list_to_raw_expr_list(agg)[0] for agg in aggs_list]
541523
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
542524

543525
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
@@ -552,20 +534,9 @@ def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
552534
Returns:
553535
DataFrame after sorting.
554536
"""
555-
expr_seq: list[Expr | str] = []
556-
for e in exprs:
557-
if isinstance(e, SortExpr):
558-
continue
559-
if isinstance(e, (Expr, str)):
560-
expr_seq.append(e)
561-
else:
562-
msg = (
563-
f"Expected Expr or column name, got {type(e).__name__}. "
564-
"Use col() or lit() to construct expressions."
565-
)
566-
raise TypeError(msg)
567-
raw_exprs_iter = iter(_to_expr_list(expr_seq))
568-
exprs_raw: list[Any] = []
537+
expr_seq = [e for e in exprs if not isinstance(e, SortExpr)]
538+
raw_exprs_iter = iter(expr_list_to_raw_expr_list(expr_seq))
539+
exprs_raw = []
569540
for e in exprs:
570541
if isinstance(e, SortExpr):
571542
exprs_raw.append(sort_or_default(e))
@@ -794,7 +765,7 @@ def join_on(
794765
Returns:
795766
DataFrame after join.
796767
"""
797-
exprs = [_ensure_expr(expr) for expr in on_exprs]
768+
exprs = [expr_list_to_raw_expr_list(expr)[0] for expr in on_exprs]
798769
return DataFrame(self.df.join_on(right.df, exprs, how))
799770

800771
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 6 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -215,29 +215,15 @@
215215
]
216216

217217

218-
def _to_expr_list(exprs: Sequence[Expr | str]) -> list[expr_internal.Expr]:
219-
"""Convert a sequence of expressions or column names to raw expressions."""
220-
return [Expr.column(e).expr if isinstance(e, str) else e.expr for e in exprs]
221-
222-
223-
def _ensure_expr(value: Any) -> expr_internal.Expr:
224-
"""Return the internal expression or raise if the value is not an Expr."""
225-
if isinstance(value, Expr):
226-
return value.expr
227-
msg = (
228-
f"Expected Expr, got {type(value).__name__}. "
229-
"Use col() or lit() to construct expressions."
230-
)
231-
raise TypeError(msg)
232-
233-
234218
def expr_list_to_raw_expr_list(
235-
expr_list: Optional[list[Expr] | Expr],
219+
expr_list: Optional[Sequence[Expr | str] | Expr | str],
236220
) -> Optional[list[expr_internal.Expr]]:
237-
"""Helper function to convert an optional list to raw expressions."""
238-
if isinstance(expr_list, Expr):
221+
"""Convert a sequence of expressions or column names to raw expressions."""
222+
if isinstance(expr_list, (Expr, str)):
239223
expr_list = [expr_list]
240-
return [e.expr for e in expr_list] if expr_list is not None else None
224+
if expr_list is None:
225+
return None
226+
return [Expr.column(e).expr if isinstance(e, str) else e.expr for e in expr_list]
241227

242228

243229
def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:

0 commit comments

Comments
 (0)