Skip to content

Commit 3a364d5

Browse files
timsaucer and claude committed
Add missing scalar functions: get_field, union_extract, union_tag, arrow_metadata, version, row
Expose upstream DataFusion scalar functions that were not yet available in the Python API. Closes #1453.

- get_field: extracts a field from a struct or map by name
- union_extract: extracts a value from a union type by field name
- union_tag: returns the active field name of a union type
- arrow_metadata: returns Arrow field metadata (all or by key)
- version: returns the DataFusion version string
- row: alias for the struct constructor

Note: arrow_try_cast was listed in the issue but does not exist in DataFusion 53, so it is not included.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 73a9d53 commit 3a364d5

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

crates/core/src/functions.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,8 +631,29 @@ expr_fn_vec!(named_struct);
631631
// Python-facing wrappers for core scalar functions. The `expr_fn!` and
// `expr_fn_vec!` macros are defined outside this view.
// NOTE(review): presumably each invocation generates a `#[pyfunction]` with
// the given name -- confirm against the macro definitions.
expr_fn!(from_unixtime, unixtime);
632632
expr_fn!(arrow_typeof, arg_1);
633633
expr_fn!(arrow_cast, arg_1 datatype);
634+
// `arrow_metadata` uses the `_vec` form; the Python wrapper forwards a
// variable number of expressions to it.
expr_fn_vec!(arrow_metadata);
635+
// `union_tag` is invoked with a single expression from the Python wrapper.
expr_fn!(union_tag, arg1);
634636
expr_fn!(random);
635637

638+
#[pyfunction]
639+
fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr {
640+
functions::core::get_field()
641+
.call(vec![expr.into(), name.into()])
642+
.into()
643+
}
644+
645+
#[pyfunction]
646+
fn union_extract(union_expr: PyExpr, field_name: PyExpr) -> PyExpr {
647+
functions::core::union_extract()
648+
.call(vec![union_expr.into(), field_name.into()])
649+
.into()
650+
}
651+
652+
#[pyfunction]
653+
fn version() -> PyExpr {
654+
functions::core::version().call(vec![]).into()
655+
}
656+
636657
// Array Functions
637658
array_fn!(array_append, array element);
638659
array_fn!(array_to_string, array delimiter);
@@ -940,6 +961,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
940961
m.add_wrapped(wrap_pyfunction!(array_agg))?;
941962
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
942963
m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
964+
m.add_wrapped(wrap_pyfunction!(arrow_metadata))?;
943965
m.add_wrapped(wrap_pyfunction!(ascii))?;
944966
m.add_wrapped(wrap_pyfunction!(asin))?;
945967
m.add_wrapped(wrap_pyfunction!(asinh))?;
@@ -1063,6 +1085,10 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
10631085
m.add_wrapped(wrap_pyfunction!(trim))?;
10641086
m.add_wrapped(wrap_pyfunction!(trunc))?;
10651087
m.add_wrapped(wrap_pyfunction!(upper))?;
1088+
m.add_wrapped(wrap_pyfunction!(get_field))?;
1089+
m.add_wrapped(wrap_pyfunction!(union_extract))?;
1090+
m.add_wrapped(wrap_pyfunction!(union_tag))?;
1091+
m.add_wrapped(wrap_pyfunction!(version))?;
10661092
m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision
10671093
m.add_wrapped(wrap_pyfunction!(var_pop))?;
10681094
m.add_wrapped(wrap_pyfunction!(var_sample))?;

python/datafusion/functions.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
"array_to_string",
9191
"array_union",
9292
"arrow_cast",
93+
"arrow_metadata",
9394
"arrow_typeof",
9495
"ascii",
9596
"asin",
@@ -149,6 +150,7 @@
149150
"floor",
150151
"from_unixtime",
151152
"gcd",
153+
"get_field",
152154
"in_list",
153155
"initcap",
154156
"isnan",
@@ -242,6 +244,7 @@
242244
"reverse",
243245
"right",
244246
"round",
247+
"row",
245248
"row_number",
246249
"rpad",
247250
"rtrim",
@@ -282,12 +285,15 @@
282285
"translate",
283286
"trim",
284287
"trunc",
288+
"union_extract",
289+
"union_tag",
285290
"upper",
286291
"uuid",
287292
"var",
288293
"var_pop",
289294
"var_samp",
290295
"var_sample",
296+
"version",
291297
"when",
292298
# Window Functions
293299
"window",
@@ -2492,6 +2498,86 @@ def arrow_cast(expr: Expr, data_type: Expr) -> Expr:
24922498
return Expr(f.arrow_cast(expr.expr, data_type.expr))
24932499

24942500

2501+
def arrow_metadata(*args: Expr) -> Expr:
    """Return Arrow metadata attached to the input expression.

    With a single argument the result is a Map holding every metadata
    key-value pair. With two arguments the result is the value stored under
    the given metadata key.

    Args:
        args: The expression to inspect, optionally followed by a metadata
            key string.

    Returns:
        Either a Map of all metadata or the value for one metadata key.
    """
    raw_args = [arg.expr for arg in args]
    return Expr(f.arrow_metadata(*raw_args))
2515+
2516+
2517+
def get_field(expr: Expr, name: Expr) -> Expr:
    """Look up a single named field inside a struct or map expression.

    Args:
        expr: Expression that evaluates to a struct or a map.
        name: Name of the field to pull out.

    Returns:
        An expression yielding the value of the named field.
    """
    raw_container, raw_name = expr.expr, name.expr
    return Expr(f.get_field(raw_container, raw_name))
2528+
2529+
2530+
def union_extract(union_expr: Expr, field_name: Expr) -> Expr:
    """Pull the value of one named field out of a union-typed expression.

    The value is returned only when the named field is the currently
    selected variant of the union; in every other case the result is NULL.

    Args:
        union_expr: Expression whose type is a union.
        field_name: Name of the union field to extract.

    Returns:
        The value of the field, or NULL when that field is not selected.
    """
    raw_union, raw_field = union_expr.expr, field_name.expr
    return Expr(f.union_extract(raw_union, raw_field))
2544+
2545+
2546+
def union_tag(union_expr: Expr) -> Expr:
    """Report which field of a union type is currently active.

    Args:
        union_expr: Expression whose type is a union.

    Returns:
        The name (tag) of the field currently selected in the union.
    """
    raw_union = union_expr.expr
    return Expr(f.union_tag(raw_union))
2556+
2557+
2558+
def version() -> Expr:
    """Return an expression for the DataFusion version string.

    Returns:
        A string describing the DataFusion version.
    """
    raw_version = f.version()
    return Expr(raw_version)
2565+
2566+
2567+
def row(*args: Expr) -> Expr:
    """Build a struct from the given expressions.

    This function is an alias for :py:func:`struct` and simply forwards all
    of its arguments to it.

    Args:
        args: The expressions that become the members of the struct.

    Returns:
        A struct expression.
    """
    struct_expr = struct(*args)
    return struct_expr
2579+
2580+
24952581
def random() -> Expr:
24962582
"""Returns a random value in the range ``0.0 <= x < 1.0``.
24972583

0 commit comments

Comments
 (0)