Skip to content

Commit 2771621

Browse files
timsaucer and claude committed
Accept str for field name and type parameters in scalar functions
Allow arrow_cast, get_field, and union_extract to accept plain str arguments instead of requiring Expr wrappers. Also improve arrow_metadata test coverage and fix parameter shadowing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8485932 commit 2771621

File tree

2 files changed

+34
-20
lines changed

2 files changed

+34
-20
lines changed

python/datafusion/functions.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2533,19 +2533,20 @@ def arrow_typeof(arg: Expr) -> Expr:
25332533
return Expr(f.arrow_typeof(arg.expr))
25342534

25352535

2536-
def arrow_cast(expr: Expr, data_type: Expr) -> Expr:
2536+
def arrow_cast(expr: Expr, data_type: Expr | str) -> Expr:
25372537
"""Casts an expression to a specified data type.
25382538
25392539
Examples:
25402540
>>> ctx = dfn.SessionContext()
25412541
>>> df = ctx.from_pydict({"a": [1]})
2542-
>>> data_type = dfn.string_literal("Float64")
25432542
>>> result = df.select(
2544-
... dfn.functions.arrow_cast(dfn.col("a"), data_type).alias("c")
2543+
... dfn.functions.arrow_cast(dfn.col("a"), "Float64").alias("c")
25452544
... )
25462545
>>> result.collect_column("c")[0].as_py()
25472546
1.0
25482547
"""
2548+
if isinstance(data_type, str):
2549+
data_type = Expr.string_literal(data_type)
25492550
return Expr(f.arrow_cast(expr.expr, data_type.expr))
25502551

25512552

@@ -2561,11 +2562,10 @@ def arrow_metadata(*args: Expr) -> Expr:
25612562
Returns:
25622563
A Map of metadata or a specific metadata value.
25632564
"""
2564-
args = [arg.expr for arg in args]
2565-
return Expr(f.arrow_metadata(*args))
2565+
return Expr(f.arrow_metadata(*[arg.expr for arg in args]))
25662566

25672567

2568-
def get_field(expr: Expr, name: Expr) -> Expr:
2568+
def get_field(expr: Expr, name: Expr | str) -> Expr:
25692569
"""Extracts a field from a struct or map by name.
25702570
25712571
Args:
@@ -2575,10 +2575,12 @@ def get_field(expr: Expr, name: Expr) -> Expr:
25752575
Returns:
25762576
The value of the named field.
25772577
"""
2578+
if isinstance(name, str):
2579+
name = Expr.string_literal(name)
25782580
return Expr(f.get_field(expr.expr, name.expr))
25792581

25802582

2581-
def union_extract(union_expr: Expr, field_name: Expr) -> Expr:
2583+
def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr:
25822584
"""Extracts a value from a union type by field name.
25832585
25842586
Returns the value of the named field if it is the currently selected
@@ -2591,6 +2593,8 @@ def union_extract(union_expr: Expr, field_name: Expr) -> Expr:
25912593
Returns:
25922594
The extracted value or NULL.
25932595
"""
2596+
if isinstance(field_name, str):
2597+
field_name = Expr.string_literal(field_name)
25942598
return Expr(f.union_extract(union_expr.expr, field_name.expr))
25952599

25962600

python/tests/test_functions.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,11 +1143,8 @@ def test_make_time(df):
11431143

11441144
def test_arrow_cast(df):
11451145
df = df.select(
1146-
# we use `string_literal` to return utf8 instead of `literal` which returns
1147-
# utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view
1148-
# https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
1149-
f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"),
1150-
f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"),
1146+
f.arrow_cast(column("b"), "Float64").alias("b_as_float"),
1147+
f.arrow_cast(column("b"), "Int32").alias("b_as_int"),
11511148
)
11521149
result = df.collect()
11531150
assert len(result) == 1
@@ -1482,20 +1479,35 @@ def test_get_field(df):
14821479
),
14831480
)
14841481
result = df.select(
1485-
f.get_field(column("s"), string_literal("x")).alias("x_val"),
1486-
f.get_field(column("s"), string_literal("y")).alias("y_val"),
1482+
f.get_field(column("s"), "x").alias("x_val"),
1483+
f.get_field(column("s"), "y").alias("y_val"),
14871484
).collect()[0]
14881485

14891486
assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view())
14901487
assert result.column(1) == pa.array([4, 5, 6])
14911488

14921489

1493-
def test_arrow_metadata(df):
1490+
def test_arrow_metadata():
1491+
ctx = SessionContext()
1492+
field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"})
1493+
schema = pa.schema([field])
1494+
batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], schema=schema)
1495+
df = ctx.create_dataframe([[batch]])
1496+
1497+
# One-argument form: returns a Map of all metadata key-value pairs
14941498
result = df.select(
1495-
f.arrow_metadata(column("a")).alias("meta"),
1499+
f.arrow_metadata(column("val")).alias("meta"),
14961500
).collect()[0]
1497-
# The metadata column should be returned as a map type (possibly empty)
14981501
assert result.column(0).type == pa.map_(pa.utf8(), pa.utf8())
1502+
meta = result.column(0)[0].as_py()
1503+
assert ("key1", "value1") in meta
1504+
assert ("key2", "value2") in meta
1505+
1506+
# Two-argument form: returns the value for a specific metadata key
1507+
result = df.select(
1508+
f.arrow_metadata(column("val"), string_literal("key1")).alias("meta_val"),
1509+
).collect()[0]
1510+
assert result.column(0)[0].as_py() == "value1"
14991511

15001512

15011513
def test_version():
@@ -1535,7 +1547,5 @@ def test_union_extract():
15351547
arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1])
15361548
df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]])
15371549

1538-
result = df.select(
1539-
f.union_extract(column("u"), string_literal("int")).alias("val")
1540-
).collect()[0]
1550+
result = df.select(f.union_extract(column("u"), "int").alias("val")).collect()[0]
15411551
assert result.column(0).to_pylist() == [1, None, 2]

0 commit comments

Comments
 (0)