From 148f62e3a4057b133dd6dc6dbc770f19b40e8825 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 31 Mar 2026 12:09:20 -0400 Subject: [PATCH 1/9] Add missing scalar functions: get_field, union_extract, union_tag, arrow_metadata, version, row Expose upstream DataFusion scalar functions that were not yet available in the Python API. Closes #1453. - get_field: extracts a field from a struct or map by name - union_extract: extracts a value from a union type by field name - union_tag: returns the active field name of a union type - arrow_metadata: returns Arrow field metadata (all or by key) - version: returns the DataFusion version string - row: alias for the struct constructor Note: arrow_try_cast was listed in the issue but does not exist in DataFusion 53, so it is not included. Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/core/src/functions.rs | 26 ++++++++++ python/datafusion/functions.py | 86 ++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index 3f07da95b..94070cc95 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -644,8 +644,29 @@ expr_fn_vec!(named_struct); expr_fn!(from_unixtime, unixtime); expr_fn!(arrow_typeof, arg_1); expr_fn!(arrow_cast, arg_1 datatype); +expr_fn_vec!(arrow_metadata); +expr_fn!(union_tag, arg1); expr_fn!(random); +#[pyfunction] +fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr { + functions::core::get_field() + .call(vec![expr.into(), name.into()]) + .into() +} + +#[pyfunction] +fn union_extract(union_expr: PyExpr, field_name: PyExpr) -> PyExpr { + functions::core::union_extract() + .call(vec![union_expr.into(), field_name.into()]) + .into() +} + +#[pyfunction] +fn version() -> PyExpr { + functions::core::version().call(vec![]).into() +} + // Array Functions array_fn!(array_append, array element); array_fn!(array_to_string, array delimiter); @@ -953,6 +974,7 @@ pub(crate) fn init_module(m: &Bound<'_, 
PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(array_agg))?; m.add_wrapped(wrap_pyfunction!(arrow_typeof))?; m.add_wrapped(wrap_pyfunction!(arrow_cast))?; + m.add_wrapped(wrap_pyfunction!(arrow_metadata))?; m.add_wrapped(wrap_pyfunction!(ascii))?; m.add_wrapped(wrap_pyfunction!(asin))?; m.add_wrapped(wrap_pyfunction!(asinh))?; @@ -1081,6 +1103,10 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(trim))?; m.add_wrapped(wrap_pyfunction!(trunc))?; m.add_wrapped(wrap_pyfunction!(upper))?; + m.add_wrapped(wrap_pyfunction!(get_field))?; + m.add_wrapped(wrap_pyfunction!(union_extract))?; + m.add_wrapped(wrap_pyfunction!(union_tag))?; + m.add_wrapped(wrap_pyfunction!(version))?; m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision m.add_wrapped(wrap_pyfunction!(var_pop))?; m.add_wrapped(wrap_pyfunction!(var_sample))?; diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index f1ea3d256..022e1699d 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -90,6 +90,7 @@ "array_to_string", "array_union", "arrow_cast", + "arrow_metadata", "arrow_typeof", "ascii", "asin", @@ -152,6 +153,7 @@ "floor", "from_unixtime", "gcd", + "get_field", "greatest", "ifnull", "in_list", @@ -250,6 +252,7 @@ "reverse", "right", "round", + "row", "row_number", "rpad", "rtrim", @@ -290,12 +293,15 @@ "translate", "trim", "trunc", + "union_extract", + "union_tag", "upper", "uuid", "var", "var_pop", "var_samp", "var_sample", + "version", "when", # Window Functions "window", @@ -2612,6 +2618,86 @@ def arrow_cast(expr: Expr, data_type: Expr) -> Expr: return Expr(f.arrow_cast(expr.expr, data_type.expr)) +def arrow_metadata(*args: Expr) -> Expr: + """Returns the metadata of the input expression. + + If called with one argument, returns a Map of all metadata key-value pairs. + If called with two arguments, returns the value for the specified metadata key. 
+ + Args: + args: An expression, optionally followed by a metadata key string. + + Returns: + A Map of metadata or a specific metadata value. + """ + args = [arg.expr for arg in args] + return Expr(f.arrow_metadata(*args)) + + +def get_field(expr: Expr, name: Expr) -> Expr: + """Extracts a field from a struct or map by name. + + Args: + expr: A struct or map expression. + name: The field name to extract. + + Returns: + The value of the named field. + """ + return Expr(f.get_field(expr.expr, name.expr)) + + +def union_extract(union_expr: Expr, field_name: Expr) -> Expr: + """Extracts a value from a union type by field name. + + Returns the value of the named field if it is the currently selected + variant, otherwise returns NULL. + + Args: + union_expr: A union-typed expression. + field_name: The name of the field to extract. + + Returns: + The extracted value or NULL. + """ + return Expr(f.union_extract(union_expr.expr, field_name.expr)) + + +def union_tag(union_expr: Expr) -> Expr: + """Returns the tag (active field name) of a union type. + + Args: + union_expr: A union-typed expression. + + Returns: + The name of the currently selected field in the union. + """ + return Expr(f.union_tag(union_expr.expr)) + + +def version() -> Expr: + """Returns the DataFusion version string. + + Returns: + A string describing the DataFusion version. + """ + return Expr(f.version()) + + +def row(*args: Expr) -> Expr: + """Returns a struct with the given arguments. + + This is an alias for :py:func:`struct`. + + Args: + args: The expressions to include in the struct. + + Returns: + A struct expression. + """ + return struct(*args) + + def random() -> Expr: """Returns a random value in the range ``0.0 <= x < 1.0``. From ea2370af8df81655cf089513b17355edf850c3f2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 31 Mar 2026 12:14:47 -0400 Subject: [PATCH 2/9] Add tests for new scalar functions Tests for get_field, arrow_metadata, version, row, union_tag, and union_extract. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- python/tests/test_functions.py | 70 ++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 74fcbffb4..a4d7482e4 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -1660,3 +1660,73 @@ def df_with_nulls(): def test_conditional_functions(df_with_nulls, expr, expected): result = df_with_nulls.select(expr.alias("result")).collect()[0] assert result.column(0) == expected + + +def test_get_field(df): + df = df.with_column( + "s", + f.named_struct( + [ + ("x", column("a")), + ("y", column("b")), + ] + ), + ) + result = df.select( + f.get_field(column("s"), string_literal("x")).alias("x_val"), + f.get_field(column("s"), string_literal("y")).alias("y_val"), + ).collect()[0] + + assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view()) + assert result.column(1) == pa.array([4, 5, 6]) + + +def test_arrow_metadata(df): + result = df.select( + f.arrow_metadata(column("a")).alias("meta"), + ).collect()[0] + # The metadata column should be returned as a map type (possibly empty) + assert result.column(0).type == pa.map_(pa.utf8(), pa.utf8()) + + +def test_version(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + result = df.select(f.version().alias("v")).collect()[0] + version_str = result.column(0)[0].as_py() + assert "Apache DataFusion" in version_str + + +def test_row(df): + result = df.select( + f.row(column("a"), column("b")).alias("r"), + f.struct(column("a"), column("b")).alias("s"), + ).collect()[0] + # row is an alias for struct, so they should produce the same output + assert result.column(0) == result.column(1) + + +def test_union_tag(): + ctx = SessionContext() + types = pa.array([0, 1, 0], type=pa.int8()) + offsets = pa.array([0, 0, 1], type=pa.int32()) + children = [pa.array([1, 2]), pa.array(["hello"])] + arr = pa.UnionArray.from_dense(types, offsets, 
children, ["int", "str"], [0, 1]) + df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]]) + + result = df.select(f.union_tag(column("u")).alias("tag")).collect()[0] + assert result.column(0).to_pylist() == ["int", "str", "int"] + + +def test_union_extract(): + ctx = SessionContext() + types = pa.array([0, 1, 0], type=pa.int8()) + offsets = pa.array([0, 0, 1], type=pa.int32()) + children = [pa.array([1, 2]), pa.array(["hello"])] + arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1]) + df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]]) + + result = df.select( + f.union_extract(column("u"), string_literal("int")).alias("val") + ).collect()[0] + assert result.column(0).to_pylist() == [1, None, 2] From 02eb25552d5fd3ef51e18cd09b7c5caff90fdf4d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 3 Apr 2026 15:35:55 -0400 Subject: [PATCH 3/9] Accept str for field name and type parameters in scalar functions Allow arrow_cast, get_field, and union_extract to accept plain str arguments instead of requiring Expr wrappers. Also improve arrow_metadata test coverage and fix parameter shadowing. Co-Authored-By: Claude Opus 4.6 (1M context) --- python/datafusion/functions.py | 18 ++++++++++------- python/tests/test_functions.py | 36 ++++++++++++++++++++++------------ 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 022e1699d..7ed8d37cc 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -2602,19 +2602,20 @@ def arrow_typeof(arg: Expr) -> Expr: return Expr(f.arrow_typeof(arg.expr)) -def arrow_cast(expr: Expr, data_type: Expr) -> Expr: +def arrow_cast(expr: Expr, data_type: Expr | str) -> Expr: """Casts an expression to a specified data type. 
Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1]}) - >>> data_type = dfn.string_literal("Float64") >>> result = df.select( - ... dfn.functions.arrow_cast(dfn.col("a"), data_type).alias("c") + ... dfn.functions.arrow_cast(dfn.col("a"), "Float64").alias("c") ... ) >>> result.collect_column("c")[0].as_py() 1.0 """ + if isinstance(data_type, str): + data_type = Expr.string_literal(data_type) return Expr(f.arrow_cast(expr.expr, data_type.expr)) @@ -2630,11 +2631,10 @@ def arrow_metadata(*args: Expr) -> Expr: Returns: A Map of metadata or a specific metadata value. """ - args = [arg.expr for arg in args] - return Expr(f.arrow_metadata(*args)) + return Expr(f.arrow_metadata(*[arg.expr for arg in args])) -def get_field(expr: Expr, name: Expr) -> Expr: +def get_field(expr: Expr, name: Expr | str) -> Expr: """Extracts a field from a struct or map by name. Args: @@ -2644,10 +2644,12 @@ def get_field(expr: Expr, name: Expr) -> Expr: Returns: The value of the named field. """ + if isinstance(name, str): + name = Expr.string_literal(name) return Expr(f.get_field(expr.expr, name.expr)) -def union_extract(union_expr: Expr, field_name: Expr) -> Expr: +def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr: """Extracts a value from a union type by field name. Returns the value of the named field if it is the currently selected @@ -2660,6 +2662,8 @@ def union_extract(union_expr: Expr, field_name: Expr) -> Expr: Returns: The extracted value or NULL. 
""" + if isinstance(field_name, str): + field_name = Expr.string_literal(field_name) return Expr(f.union_extract(union_expr.expr, field_name.expr)) diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index a4d7482e4..3e22ca7a2 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -1143,11 +1143,8 @@ def test_make_time(df): def test_arrow_cast(df): df = df.select( - # we use `string_literal` to return utf8 instead of `literal` which returns - # utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view - # https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179 - f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"), - f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"), + f.arrow_cast(column("b"), "Float64").alias("b_as_float"), + f.arrow_cast(column("b"), "Int32").alias("b_as_int"), ) result = df.collect() assert len(result) == 1 @@ -1673,20 +1670,35 @@ def test_get_field(df): ), ) result = df.select( - f.get_field(column("s"), string_literal("x")).alias("x_val"), - f.get_field(column("s"), string_literal("y")).alias("y_val"), + f.get_field(column("s"), "x").alias("x_val"), + f.get_field(column("s"), "y").alias("y_val"), ).collect()[0] assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view()) assert result.column(1) == pa.array([4, 5, 6]) -def test_arrow_metadata(df): +def test_arrow_metadata(): + ctx = SessionContext() + field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"}) + schema = pa.schema([field]) + batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], schema=schema) + df = ctx.create_dataframe([[batch]]) + + # One-argument form: returns a Map of all metadata key-value pairs result = df.select( - f.arrow_metadata(column("a")).alias("meta"), + f.arrow_metadata(column("val")).alias("meta"), ).collect()[0] - # The metadata 
column should be returned as a map type (possibly empty) assert result.column(0).type == pa.map_(pa.utf8(), pa.utf8()) + meta = result.column(0)[0].as_py() + assert ("key1", "value1") in meta + assert ("key2", "value2") in meta + + # Two-argument form: returns the value for a specific metadata key + result = df.select( + f.arrow_metadata(column("val"), string_literal("key1")).alias("meta_val"), + ).collect()[0] + assert result.column(0)[0].as_py() == "value1" def test_version(): @@ -1726,7 +1738,5 @@ def test_union_extract(): arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1]) df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]]) - result = df.select( - f.union_extract(column("u"), string_literal("int")).alias("val") - ).collect()[0] + result = df.select(f.union_extract(column("u"), "int").alias("val")).collect()[0] assert result.column(0).to_pylist() == [1, None, 2] From df1ead15a84cf0e3f1795ed3b422a1c41f4a5720 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 3 Apr 2026 15:42:15 -0400 Subject: [PATCH 4/9] Accept str for key parameter in arrow_metadata for consistency Co-Authored-By: Claude Opus 4.6 (1M context) --- python/datafusion/functions.py | 11 ++++++++--- python/tests/test_functions.py | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 7ed8d37cc..bf5ac66ba 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -2619,19 +2619,24 @@ def arrow_cast(expr: Expr, data_type: Expr | str) -> Expr: return Expr(f.arrow_cast(expr.expr, data_type.expr)) -def arrow_metadata(*args: Expr) -> Expr: +def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr: """Returns the metadata of the input expression. If called with one argument, returns a Map of all metadata key-value pairs. If called with two arguments, returns the value for the specified metadata key. 
Args: - args: An expression, optionally followed by a metadata key string. + expr: An expression whose metadata to retrieve. + key: Optional metadata key to look up. Can be a string or an Expr. Returns: A Map of metadata or a specific metadata value. """ - return Expr(f.arrow_metadata(*[arg.expr for arg in args])) + if key is None: + return Expr(f.arrow_metadata(expr.expr)) + if isinstance(key, str): + key = Expr.string_literal(key) + return Expr(f.arrow_metadata(expr.expr, key.expr)) def get_field(expr: Expr, name: Expr | str) -> Expr: diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 3e22ca7a2..d3faf8f26 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -20,7 +20,7 @@ import numpy as np import pyarrow as pa import pytest -from datafusion import SessionContext, column, literal, string_literal +from datafusion import SessionContext, column, literal from datafusion import functions as f np.seterr(invalid="ignore") @@ -1696,7 +1696,7 @@ def test_arrow_metadata(): # Two-argument form: returns the value for a specific metadata key result = df.select( - f.arrow_metadata(column("val"), string_literal("key1")).alias("meta_val"), + f.arrow_metadata(column("val"), "key1").alias("meta_val"), ).collect()[0] assert result.column(0)[0].as_py() == "value1" From a662e18fb6f4ec514bd2d95f64f41c9d9c861a33 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 6 Apr 2026 08:02:46 -0400 Subject: [PATCH 5/9] Add doctest examples and fix docstring style for new scalar functions Replace Args/Returns sections with doctest Examples blocks for arrow_metadata, get_field, union_extract, union_tag, and version to match existing codebase conventions. Simplify row to alias-style docstring with See Also reference. Document that arrow_cast accepts both str and Expr for data_type. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- python/datafusion/functions.py | 106 +++++++++++++++++++++++---------- 1 file changed, 75 insertions(+), 31 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index b313c6e68..d16d69960 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -2637,6 +2637,8 @@ def arrow_typeof(arg: Expr) -> Expr: def arrow_cast(expr: Expr, data_type: Expr | str) -> Expr: """Casts an expression to a specified data type. + The ``data_type`` can be a string or an ``Expr``. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1]}) @@ -2657,12 +2659,26 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr: If called with one argument, returns a Map of all metadata key-value pairs. If called with two arguments, returns the value for the specified metadata key. - Args: - expr: An expression whose metadata to retrieve. - key: Optional metadata key to look up. Can be a string or an Expr. + Examples: + >>> import pyarrow as pa + >>> field = pa.field("val", pa.int64(), metadata={"k": "v"}) + >>> schema = pa.schema([field]) + >>> batch = pa.RecordBatch.from_arrays([pa.array([1])], schema=schema) + >>> ctx = dfn.SessionContext() + >>> df = ctx.create_dataframe([[batch]]) + >>> result = df.select( + ... dfn.functions.arrow_metadata(dfn.col("val")).alias("meta") + ... ) + >>> ("k", "v") in result.collect_column("meta")[0].as_py() + True - Returns: - A Map of metadata or a specific metadata value. + >>> result = df.select( + ... dfn.functions.arrow_metadata( + ... dfn.col("val"), key="k" + ... ).alias("meta_val") + ... ) + >>> result.collect_column("meta_val")[0].as_py() + 'v' """ if key is None: return Expr(f.arrow_metadata(expr.expr)) @@ -2674,12 +2690,20 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr: def get_field(expr: Expr, name: Expr | str) -> Expr: """Extracts a field from a struct or map by name. 
- Args: - expr: A struct or map expression. - name: The field name to extract. - - Returns: - The value of the named field. + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1], "b": [2]}) + >>> df = df.with_column( + ... "s", + ... dfn.functions.named_struct( + ... [("x", dfn.col("a")), ("y", dfn.col("b"))] + ... ), + ... ) + >>> result = df.select( + ... dfn.functions.get_field(dfn.col("s"), "x").alias("x_val") + ... ) + >>> result.collect_column("x_val")[0].as_py() + 1 """ if isinstance(name, str): name = Expr.string_literal(name) @@ -2692,12 +2716,22 @@ def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr: Returns the value of the named field if it is the currently selected variant, otherwise returns NULL. - Args: - union_expr: A union-typed expression. - field_name: The name of the field to extract. - - Returns: - The extracted value or NULL. + Examples: + >>> import pyarrow as pa + >>> ctx = dfn.SessionContext() + >>> types = pa.array([0, 1, 0], type=pa.int8()) + >>> offsets = pa.array([0, 0, 1], type=pa.int32()) + >>> arr = pa.UnionArray.from_dense( + ... types, offsets, [pa.array([1, 2]), pa.array(["hi"])], + ... ["int", "str"], [0, 1], + ... ) + >>> batch = pa.RecordBatch.from_arrays([arr], names=["u"]) + >>> df = ctx.create_dataframe([[batch]]) + >>> result = df.select( + ... dfn.functions.union_extract(dfn.col("u"), "int").alias("val") + ... ) + >>> result.collect_column("val").to_pylist() + [1, None, 2] """ if isinstance(field_name, str): field_name = Expr.string_literal(field_name) @@ -2707,11 +2741,22 @@ def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr: def union_tag(union_expr: Expr) -> Expr: """Returns the tag (active field name) of a union type. - Args: - union_expr: A union-typed expression. - - Returns: - The name of the currently selected field in the union. 
+ Examples: + >>> import pyarrow as pa + >>> ctx = dfn.SessionContext() + >>> types = pa.array([0, 1, 0], type=pa.int8()) + >>> offsets = pa.array([0, 0, 1], type=pa.int32()) + >>> arr = pa.UnionArray.from_dense( + ... types, offsets, [pa.array([1, 2]), pa.array(["hi"])], + ... ["int", "str"], [0, 1], + ... ) + >>> batch = pa.RecordBatch.from_arrays([arr], names=["u"]) + >>> df = ctx.create_dataframe([[batch]]) + >>> result = df.select( + ... dfn.functions.union_tag(dfn.col("u")).alias("tag") + ... ) + >>> result.collect_column("tag").to_pylist() + ['int', 'str', 'int'] """ return Expr(f.union_tag(union_expr.expr)) @@ -2719,8 +2764,12 @@ def union_tag(union_expr: Expr) -> Expr: def version() -> Expr: """Returns the DataFusion version string. - Returns: - A string describing the DataFusion version. + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> result = df.select(dfn.functions.version().alias("v")) + >>> "Apache DataFusion" in result.collect_column("v")[0].as_py() + True """ return Expr(f.version()) @@ -2728,13 +2777,8 @@ def version() -> Expr: def row(*args: Expr) -> Expr: """Returns a struct with the given arguments. - This is an alias for :py:func:`struct`. - - Args: - args: The expressions to include in the struct. - - Returns: - A struct expression. + See Also: + This is an alias for :py:func:`struct`. """ return struct(*args) From b627d3038681fad28c764a99e6b4bad9961c9a9d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 6 Apr 2026 08:11:32 -0400 Subject: [PATCH 6/9] Support pyarrow DataType in arrow_cast Allow arrow_cast to accept a pyarrow DataType in addition to str and Expr. The DataType is converted to its string representation before being passed to DataFusion. Adds test coverage for the new input type. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- python/datafusion/functions.py | 15 +++++++++++++-- python/tests/test_functions.py | 13 +++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index d16d69960..113f043db 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -2634,10 +2634,10 @@ def arrow_typeof(arg: Expr) -> Expr: return Expr(f.arrow_typeof(arg.expr)) -def arrow_cast(expr: Expr, data_type: Expr | str) -> Expr: +def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr: """Casts an expression to a specified data type. - The ``data_type`` can be a string or an ``Expr``. + The ``data_type`` can be a string, a ``pyarrow.DataType``, or an ``Expr``. Examples: >>> ctx = dfn.SessionContext() @@ -2647,7 +2647,18 @@ def arrow_cast(expr: Expr, data_type: Expr | str) -> Expr: ... ) >>> result.collect_column("c")[0].as_py() 1.0 + + >>> import pyarrow as pa + >>> result = df.select( + ... dfn.functions.arrow_cast( + ... dfn.col("a"), data_type=pa.float64() + ... ).alias("c") + ... 
) + >>> result.collect_column("c")[0].as_py() + 1.0 """ + if isinstance(data_type, pa.DataType): + data_type = str(data_type) if isinstance(data_type, str): data_type = Expr.string_literal(data_type) return Expr(f.arrow_cast(expr.expr, data_type.expr)) diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 25ade5be4..4e99fa9e3 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -1302,6 +1302,19 @@ def test_arrow_cast(df): assert result.column(1) == pa.array([4, 5, 6], type=pa.int32()) +def test_arrow_cast_with_pyarrow_type(df): + df = df.select( + f.arrow_cast(column("b"), pa.float64()).alias("b_as_float"), + f.arrow_cast(column("b"), pa.int32()).alias("b_as_int"), + f.arrow_cast(column("b"), pa.string()).alias("b_as_str"), + ) + result = df.collect()[0] + + assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64()) + assert result.column(1) == pa.array([4, 5, 6], type=pa.int32()) + assert result.column(2) == pa.array(["4", "5", "6"], type=pa.string()) + + def test_case(df): df = df.select( f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)), From f760e7094db48c59db53ec2bc2435ecd2aa97b4c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 6 Apr 2026 08:14:55 -0400 Subject: [PATCH 7/9] Document bracket syntax shorthand in get_field docstring Note that expr["field"] is a convenient alternative when the field name is a static string, and get_field is needed for dynamic expressions. Add a second doctest example showing the bracket syntax. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- python/datafusion/functions.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 113f043db..d1c3cd0f3 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -2701,6 +2701,10 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr: def get_field(expr: Expr, name: Expr | str) -> Expr: """Extracts a field from a struct or map by name. + When the field name is a static string, the bracket operator + ``expr["field"]`` is a convenient shorthand. Use ``get_field`` + when the field name is a dynamic expression. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1], "b": [2]}) @@ -2715,6 +2719,14 @@ def get_field(expr: Expr, name: Expr | str) -> Expr: ... ) >>> result.collect_column("x_val")[0].as_py() 1 + + Equivalent using bracket syntax: + + >>> result = df.select( + ... dfn.col("s")["x"].alias("x_val") + ... ) + >>> result.collect_column("x_val")[0].as_py() + 1 """ if isinstance(name, str): name = Expr.string_literal(name) From d12f72192462d9254f1de9fe34791839956b8875 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 6 Apr 2026 08:27:31 -0400 Subject: [PATCH 8/9] Fix arrow_cast with pyarrow DataType by delegating to Expr.cast Use the existing Rust-side PyArrowType conversion via Expr.cast() instead of str() which produces pyarrow type names that DataFusion does not recognize. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- python/datafusion/functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index d1c3cd0f3..479ece680 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -2658,7 +2658,7 @@ def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr: 1.0 """ if isinstance(data_type, pa.DataType): - data_type = str(data_type) + return expr.cast(data_type) if isinstance(data_type, str): data_type = Expr.string_literal(data_type) return Expr(f.arrow_cast(expr.expr, data_type.expr)) @@ -2789,7 +2789,7 @@ def version() -> Expr: Examples: >>> ctx = dfn.SessionContext() - >>> df = ctx.from_pydict({"a": [1]}) + >>> df = ctx.empty_table() >>> result = df.select(dfn.functions.version().alias("v")) >>> "Apache DataFusion" in result.collect_column("v")[0].as_py() True From 056f712adfaae0dd159b48b2327abf989c83a651 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 6 Apr 2026 08:29:29 -0400 Subject: [PATCH 9/9] Clarify when to use arrow_cast vs Expr.cast in docstring Co-Authored-By: Claude Opus 4.6 (1M context) --- python/datafusion/functions.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 479ece680..aa7f28746 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -2637,7 +2637,13 @@ def arrow_typeof(arg: Expr) -> Expr: def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr: """Casts an expression to a specified data type. - The ``data_type`` can be a string, a ``pyarrow.DataType``, or an ``Expr``. + The ``data_type`` can be a string, a ``pyarrow.DataType``, or an + ``Expr``. For simple types, :py:meth:`Expr.cast() + <datafusion.expr.Expr.cast>` is more concise + (e.g., ``col("a").cast(pa.float64())``). 
Use ``arrow_cast`` when + you want to specify the target type as a string using DataFusion's + type syntax, which can be more readable for complex types like + ``"Timestamp(Nanosecond, None)"``. Examples: >>> ctx = dfn.SessionContext()