Skip to content

Commit d405684

Browse files
timsaucer and claude committed
Add missing scalar functions: get_field, union_extract, union_tag, arrow_metadata, version, row
Expose upstream DataFusion scalar functions that were not yet available in the Python API. Closes #1453.

- get_field: extracts a field from a struct or map by name
- union_extract: extracts a value from a union type by field name
- union_tag: returns the active field name of a union type
- arrow_metadata: returns Arrow field metadata (all or by key)
- version: returns the DataFusion version string
- row: alias for the struct constructor

Note: arrow_try_cast was listed in the issue but does not exist in DataFusion 53, so it is not included.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 645d261 commit d405684

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

crates/core/src/functions.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,8 +637,29 @@ expr_fn_vec!(named_struct);
637637
expr_fn!(from_unixtime, unixtime);
638638
expr_fn!(arrow_typeof, arg_1);
639639
expr_fn!(arrow_cast, arg_1 datatype);
640+
expr_fn_vec!(arrow_metadata);
641+
expr_fn!(union_tag, arg1);
640642
expr_fn!(random);
641643

644+
#[pyfunction]
645+
fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr {
646+
functions::core::get_field()
647+
.call(vec![expr.into(), name.into()])
648+
.into()
649+
}
650+
651+
#[pyfunction]
652+
fn union_extract(union_expr: PyExpr, field_name: PyExpr) -> PyExpr {
653+
functions::core::union_extract()
654+
.call(vec![union_expr.into(), field_name.into()])
655+
.into()
656+
}
657+
658+
#[pyfunction]
659+
fn version() -> PyExpr {
660+
functions::core::version().call(vec![]).into()
661+
}
662+
642663
// Array Functions
643664
array_fn!(array_append, array element);
644665
array_fn!(array_to_string, array delimiter);
@@ -946,6 +967,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
946967
m.add_wrapped(wrap_pyfunction!(array_agg))?;
947968
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
948969
m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
970+
m.add_wrapped(wrap_pyfunction!(arrow_metadata))?;
949971
m.add_wrapped(wrap_pyfunction!(ascii))?;
950972
m.add_wrapped(wrap_pyfunction!(asin))?;
951973
m.add_wrapped(wrap_pyfunction!(asinh))?;
@@ -1071,6 +1093,10 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
10711093
m.add_wrapped(wrap_pyfunction!(trim))?;
10721094
m.add_wrapped(wrap_pyfunction!(trunc))?;
10731095
m.add_wrapped(wrap_pyfunction!(upper))?;
1096+
m.add_wrapped(wrap_pyfunction!(get_field))?;
1097+
m.add_wrapped(wrap_pyfunction!(union_extract))?;
1098+
m.add_wrapped(wrap_pyfunction!(union_tag))?;
1099+
m.add_wrapped(wrap_pyfunction!(version))?;
10741100
m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision
10751101
m.add_wrapped(wrap_pyfunction!(var_pop))?;
10761102
m.add_wrapped(wrap_pyfunction!(var_sample))?;

python/datafusion/functions.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
"array_to_string",
9191
"array_union",
9292
"arrow_cast",
93+
"arrow_metadata",
9394
"arrow_typeof",
9495
"ascii",
9596
"asin",
@@ -152,6 +153,7 @@
152153
"floor",
153154
"from_unixtime",
154155
"gcd",
156+
"get_field",
155157
"in_list",
156158
"initcap",
157159
"isnan",
@@ -246,6 +248,7 @@
246248
"reverse",
247249
"right",
248250
"round",
251+
"row",
249252
"row_number",
250253
"rpad",
251254
"rtrim",
@@ -286,12 +289,15 @@
286289
"translate",
287290
"trim",
288291
"trunc",
292+
"union_extract",
293+
"union_tag",
289294
"upper",
290295
"uuid",
291296
"var",
292297
"var_pop",
293298
"var_samp",
294299
"var_sample",
300+
"version",
295301
"when",
296302
# Window Functions
297303
"window",
@@ -2543,6 +2549,86 @@ def arrow_cast(expr: Expr, data_type: Expr) -> Expr:
25432549
return Expr(f.arrow_cast(expr.expr, data_type.expr))
25442550

25452551

2552+
def arrow_metadata(*args: Expr) -> Expr:
    """Returns the metadata of the input expression.

    With a single argument the result is a Map holding every metadata
    key-value pair. With two arguments the result is the value stored under
    the specified metadata key.

    Args:
        args: An expression, optionally followed by a metadata key string.

    Returns:
        A Map of metadata or a specific metadata value.
    """
    # Unwrap each Python-level Expr to its underlying expression object
    # before forwarding to the internal function.
    return Expr(f.arrow_metadata(*(arg.expr for arg in args)))
2566+
2567+
2568+
def get_field(expr: Expr, name: Expr) -> Expr:
    """Extracts a field from a struct or map by name.

    Args:
        expr: A struct or map expression.
        name: The field name to extract.

    Returns:
        The value of the named field.
    """
    inner = f.get_field(expr.expr, name.expr)
    return Expr(inner)
2579+
2580+
2581+
def union_extract(union_expr: Expr, field_name: Expr) -> Expr:
    """Extracts a value from a union type by field name.

    If the named field is the currently selected variant its value is
    returned; otherwise the result is NULL.

    Args:
        union_expr: A union-typed expression.
        field_name: The name of the field to extract.

    Returns:
        The extracted value or NULL.
    """
    inner = f.union_extract(union_expr.expr, field_name.expr)
    return Expr(inner)
2595+
2596+
2597+
def union_tag(union_expr: Expr) -> Expr:
    """Returns the tag (active field name) of a union type.

    Args:
        union_expr: A union-typed expression.

    Returns:
        The name of the currently selected field in the union.
    """
    inner = f.union_tag(union_expr.expr)
    return Expr(inner)
2607+
2608+
2609+
def version() -> Expr:
    """Returns the DataFusion version string.

    Returns:
        A string describing the DataFusion version.
    """
    inner = f.version()
    return Expr(inner)
2616+
2617+
2618+
def row(*args: Expr) -> Expr:
    """Returns a struct with the given arguments.

    This is an alias for :py:func:`struct`.

    Args:
        args: The expressions to include in the struct.

    Returns:
        A struct expression.
    """
    # Pure alias: forward every positional argument unchanged.
    struct_expr = struct(*args)
    return struct_expr
2630+
2631+
25462632
def random() -> Expr:
25472633
"""Returns a random value in the range ``0.0 <= x < 1.0``.
25482634

0 commit comments

Comments
 (0)