Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions crates/core/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,8 +644,29 @@ expr_fn_vec!(named_struct);
expr_fn!(from_unixtime, unixtime);
expr_fn!(arrow_typeof, arg_1);
expr_fn!(arrow_cast, arg_1 datatype);
expr_fn_vec!(arrow_metadata);
expr_fn!(union_tag, arg1);
expr_fn!(random);

#[pyfunction]
fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr {
functions::core::get_field()
.call(vec![expr.into(), name.into()])
.into()
}

#[pyfunction]
fn union_extract(union_expr: PyExpr, field_name: PyExpr) -> PyExpr {
functions::core::union_extract()
.call(vec![union_expr.into(), field_name.into()])
.into()
}

#[pyfunction]
fn version() -> PyExpr {
functions::core::version().call(vec![]).into()
}

// Array Functions
array_fn!(array_append, array element);
array_fn!(array_to_string, array delimiter);
Expand Down Expand Up @@ -953,6 +974,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(array_agg))?;
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
m.add_wrapped(wrap_pyfunction!(arrow_metadata))?;
m.add_wrapped(wrap_pyfunction!(ascii))?;
m.add_wrapped(wrap_pyfunction!(asin))?;
m.add_wrapped(wrap_pyfunction!(asinh))?;
Expand Down Expand Up @@ -1081,6 +1103,10 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(trim))?;
m.add_wrapped(wrap_pyfunction!(trunc))?;
m.add_wrapped(wrap_pyfunction!(upper))?;
m.add_wrapped(wrap_pyfunction!(get_field))?;
m.add_wrapped(wrap_pyfunction!(union_extract))?;
m.add_wrapped(wrap_pyfunction!(union_tag))?;
m.add_wrapped(wrap_pyfunction!(version))?;
m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision
m.add_wrapped(wrap_pyfunction!(var_pop))?;
m.add_wrapped(wrap_pyfunction!(var_sample))?;
Expand Down
101 changes: 98 additions & 3 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
"array_to_string",
"array_union",
"arrow_cast",
"arrow_metadata",
"arrow_typeof",
"ascii",
"asin",
Expand Down Expand Up @@ -152,6 +153,7 @@
"floor",
"from_unixtime",
"gcd",
"get_field",
"greatest",
"ifnull",
"in_list",
Expand Down Expand Up @@ -250,6 +252,7 @@
"reverse",
"right",
"round",
"row",
"row_number",
"rpad",
"rtrim",
Expand Down Expand Up @@ -290,12 +293,15 @@
"translate",
"trim",
"trunc",
"union_extract",
"union_tag",
"upper",
"uuid",
"var",
"var_pop",
"var_samp",
"var_sample",
"version",
"when",
# Window Functions
"window",
Expand Down Expand Up @@ -2596,22 +2602,111 @@ def arrow_typeof(arg: Expr) -> Expr:
return Expr(f.arrow_typeof(arg.expr))


def arrow_cast(expr: Expr, data_type: Expr) -> Expr:
def arrow_cast(expr: Expr, data_type: Expr | str) -> Expr:
"""Casts an expression to a specified data type.

Examples:
>>> ctx = dfn.SessionContext()
Comment on lines +2605 to 2609
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR declares Closes #1453, but the issue also lists arrow_try_cast as a missing scalar function. I verified there is no arrow_try_cast wrapper anywhere in the repo (no Python wrapper in python/datafusion/functions.py and no Rust binding in crates/core/src/functions.rs). Either add arrow_try_cast (and a corresponding unit test) or adjust the PR description/linked issue closure so we’re not closing the issue prematurely.

Copilot uses AI. Check for mistakes.
>>> df = ctx.from_pydict({"a": [1]})
>>> data_type = dfn.string_literal("Float64")
>>> result = df.select(
... dfn.functions.arrow_cast(dfn.col("a"), data_type).alias("c")
... dfn.functions.arrow_cast(dfn.col("a"), "Float64").alias("c")
... )
>>> result.collect_column("c")[0].as_py()
1.0
"""
if isinstance(data_type, str):
data_type = Expr.string_literal(data_type)
return Expr(f.arrow_cast(expr.expr, data_type.expr))


def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
"""Returns the metadata of the input expression.

If called with one argument, returns a Map of all metadata key-value pairs.
If called with two arguments, returns the value for the specified metadata key.

Args:
expr: An expression whose metadata to retrieve.
key: Optional metadata key to look up. Can be a string or an Expr.

Returns:
A Map of metadata or a specific metadata value.
"""
if key is None:
return Expr(f.arrow_metadata(expr.expr))
if isinstance(key, str):
key = Expr.string_literal(key)
return Expr(f.arrow_metadata(expr.expr, key.expr))


def get_field(expr: Expr, name: Expr | str) -> Expr:
"""Extracts a field from a struct or map by name.

Args:
expr: A struct or map expression.
name: The field name to extract.

Returns:
The value of the named field.
"""
if isinstance(name, str):
name = Expr.string_literal(name)
return Expr(f.get_field(expr.expr, name.expr))


def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr:
"""Extracts a value from a union type by field name.

Returns the value of the named field if it is the currently selected
variant, otherwise returns NULL.

Args:
union_expr: A union-typed expression.
field_name: The name of the field to extract.

Returns:
The extracted value or NULL.
"""
if isinstance(field_name, str):
field_name = Expr.string_literal(field_name)
return Expr(f.union_extract(union_expr.expr, field_name.expr))


def union_tag(union_expr: Expr) -> Expr:
"""Returns the tag (active field name) of a union type.

Args:
union_expr: A union-typed expression.

Returns:
The name of the currently selected field in the union.
"""
return Expr(f.union_tag(union_expr.expr))


def version() -> Expr:
"""Returns the DataFusion version string.

Returns:
A string describing the DataFusion version.
"""
return Expr(f.version())


def row(*args: Expr) -> Expr:
"""Returns a struct with the given arguments.

This is an alias for :py:func:`struct`.

Args:
args: The expressions to include in the struct.

Returns:
A struct expression.
"""
return struct(*args)


def random() -> Expr:
"""Returns a random value in the range ``0.0 <= x < 1.0``.

Expand Down
92 changes: 86 additions & 6 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np
import pyarrow as pa
import pytest
from datafusion import SessionContext, column, literal, string_literal
from datafusion import SessionContext, column, literal
from datafusion import functions as f

np.seterr(invalid="ignore")
Expand Down Expand Up @@ -1143,11 +1143,8 @@ def test_make_time(df):

def test_arrow_cast(df):
df = df.select(
# we use `string_literal` to return utf8 instead of `literal` which returns
# utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view
# https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"),
f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"),
f.arrow_cast(column("b"), "Float64").alias("b_as_float"),
f.arrow_cast(column("b"), "Int32").alias("b_as_int"),
)
result = df.collect()
assert len(result) == 1
Expand Down Expand Up @@ -1660,3 +1657,86 @@ def df_with_nulls():
def test_conditional_functions(df_with_nulls, expr, expected):
result = df_with_nulls.select(expr.alias("result")).collect()[0]
assert result.column(0) == expected


def test_get_field(df):
df = df.with_column(
"s",
f.named_struct(
[
("x", column("a")),
("y", column("b")),
]
),
)
result = df.select(
f.get_field(column("s"), "x").alias("x_val"),
f.get_field(column("s"), "y").alias("y_val"),
).collect()[0]

assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view())
assert result.column(1) == pa.array([4, 5, 6])


def test_arrow_metadata():
ctx = SessionContext()
field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"})
schema = pa.schema([field])
batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], schema=schema)
df = ctx.create_dataframe([[batch]])

# One-argument form: returns a Map of all metadata key-value pairs
result = df.select(
f.arrow_metadata(column("val")).alias("meta"),
).collect()[0]
assert result.column(0).type == pa.map_(pa.utf8(), pa.utf8())
meta = result.column(0)[0].as_py()
assert ("key1", "value1") in meta
assert ("key2", "value2") in meta

# Two-argument form: returns the value for a specific metadata key
result = df.select(
f.arrow_metadata(column("val"), "key1").alias("meta_val"),
).collect()[0]
assert result.column(0)[0].as_py() == "value1"


def test_version():
ctx = SessionContext()
df = ctx.from_pydict({"a": [1]})
result = df.select(f.version().alias("v")).collect()[0]
version_str = result.column(0)[0].as_py()
assert "Apache DataFusion" in version_str


def test_row(df):
result = df.select(
f.row(column("a"), column("b")).alias("r"),
f.struct(column("a"), column("b")).alias("s"),
).collect()[0]
# row is an alias for struct, so they should produce the same output
assert result.column(0) == result.column(1)


def test_union_tag():
ctx = SessionContext()
types = pa.array([0, 1, 0], type=pa.int8())
offsets = pa.array([0, 0, 1], type=pa.int32())
children = [pa.array([1, 2]), pa.array(["hello"])]
arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1])
df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]])

result = df.select(f.union_tag(column("u")).alias("tag")).collect()[0]
assert result.column(0).to_pylist() == ["int", "str", "int"]


def test_union_extract():
ctx = SessionContext()
types = pa.array([0, 1, 0], type=pa.int8())
offsets = pa.array([0, 0, 1], type=pa.int32())
children = [pa.array([1, 2]), pa.array(["hello"])]
arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1])
df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]])

result = df.select(f.union_extract(column("u"), "int").alias("val")).collect()[0]
assert result.column(0).to_pylist() == [1, None, 2]
Loading