Skip to content

Commit 8485932

Browse files
timsaucerclaude
andcommitted
Add tests for new scalar functions
Tests for get_field, arrow_metadata, version, row, union_tag, and union_extract. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d405684 commit 8485932

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

python/tests/test_functions.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1469,3 +1469,73 @@ def test_coalesce(df):
14691469
assert result.column(0) == pa.array(
14701470
["Hello", "fallback", "!"], type=pa.string_view()
14711471
)
1472+
1473+
1474+
def test_get_field(df):
1475+
df = df.with_column(
1476+
"s",
1477+
f.named_struct(
1478+
[
1479+
("x", column("a")),
1480+
("y", column("b")),
1481+
]
1482+
),
1483+
)
1484+
result = df.select(
1485+
f.get_field(column("s"), string_literal("x")).alias("x_val"),
1486+
f.get_field(column("s"), string_literal("y")).alias("y_val"),
1487+
).collect()[0]
1488+
1489+
assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view())
1490+
assert result.column(1) == pa.array([4, 5, 6])
1491+
1492+
1493+
def test_arrow_metadata(df):
1494+
result = df.select(
1495+
f.arrow_metadata(column("a")).alias("meta"),
1496+
).collect()[0]
1497+
# The metadata column should be returned as a map type (possibly empty)
1498+
assert result.column(0).type == pa.map_(pa.utf8(), pa.utf8())
1499+
1500+
1501+
def test_version():
1502+
ctx = SessionContext()
1503+
df = ctx.from_pydict({"a": [1]})
1504+
result = df.select(f.version().alias("v")).collect()[0]
1505+
version_str = result.column(0)[0].as_py()
1506+
assert "Apache DataFusion" in version_str
1507+
1508+
1509+
def test_row(df):
1510+
result = df.select(
1511+
f.row(column("a"), column("b")).alias("r"),
1512+
f.struct(column("a"), column("b")).alias("s"),
1513+
).collect()[0]
1514+
# row is an alias for struct, so they should produce the same output
1515+
assert result.column(0) == result.column(1)
1516+
1517+
1518+
def test_union_tag():
1519+
ctx = SessionContext()
1520+
types = pa.array([0, 1, 0], type=pa.int8())
1521+
offsets = pa.array([0, 0, 1], type=pa.int32())
1522+
children = [pa.array([1, 2]), pa.array(["hello"])]
1523+
arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1])
1524+
df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]])
1525+
1526+
result = df.select(f.union_tag(column("u")).alias("tag")).collect()[0]
1527+
assert result.column(0).to_pylist() == ["int", "str", "int"]
1528+
1529+
1530+
def test_union_extract():
1531+
ctx = SessionContext()
1532+
types = pa.array([0, 1, 0], type=pa.int8())
1533+
offsets = pa.array([0, 0, 1], type=pa.int32())
1534+
children = [pa.array([1, 2]), pa.array(["hello"])]
1535+
arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1])
1536+
df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]])
1537+
1538+
result = df.select(
1539+
f.union_extract(column("u"), string_literal("int")).alias("val")
1540+
).collect()[0]
1541+
assert result.column(0).to_pylist() == [1, None, 2]

0 commit comments

Comments
 (0)