Skip to content

Commit bb18bf9

Browse files
timsaucer and claude committed
Consolidate new DataFrame tests into parametrized tests
Combine the set operation tests (except_distinct, intersect_distinct, union_by_name, union_by_name_distinct) into a single parametrized test, test_set_operations_distinct. Merge the sort_by tests and convert the explain format tests to parametrized form.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ddc918d commit bb18bf9

File tree

1 file changed

+66
-69
lines changed

1 file changed

+66
-69
lines changed

python/tests/test_dataframe.py

Lines changed: 66 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -3571,43 +3571,52 @@ def test_read_parquet_file_sort_order(tmp_path, file_sort_order):
35713571
assert df.collect()[0].column(0).to_pylist() == [1, 2]
35723572

35733573

3574-
def test_except_distinct():
3574+
@pytest.mark.parametrize(
3575+
("df1_data", "df2_data", "method", "expected_a", "expected_b"),
3576+
[
3577+
pytest.param(
3578+
{"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]},
3579+
{"a": [1, 2], "b": [10, 20]},
3580+
"except_distinct",
3581+
[3],
3582+
[30],
3583+
id="except_distinct: removes matching rows and deduplicates",
3584+
),
3585+
pytest.param(
3586+
{"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]},
3587+
{"a": [1, 4], "b": [10, 40]},
3588+
"intersect_distinct",
3589+
[1],
3590+
[10],
3591+
id="intersect_distinct: keeps common rows and deduplicates",
3592+
),
3593+
pytest.param(
3594+
{"a": [1], "b": [10]},
3595+
{"b": [20], "a": [2]}, # reversed column order tests matching by name
3596+
"union_by_name",
3597+
[1, 2],
3598+
[10, 20],
3599+
id="union_by_name: matches columns by name not position",
3600+
),
3601+
pytest.param(
3602+
{"a": [1, 1], "b": [10, 10]},
3603+
{"b": [10], "a": [1]}, # reversed column order with duplicates
3604+
"union_by_name_distinct",
3605+
[1],
3606+
[10],
3607+
id="union_by_name_distinct: matches by name and deduplicates",
3608+
),
3609+
],
3610+
)
3611+
def test_set_operations_distinct(df1_data, df2_data, method, expected_a, expected_b):
35753612
ctx = SessionContext()
3576-
df1 = ctx.from_pydict({"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]})
3577-
df2 = ctx.from_pydict({"a": [1, 2], "b": [10, 20]})
3613+
df1 = ctx.from_pydict(df1_data)
3614+
df2 = ctx.from_pydict(df2_data)
35783615
result = (
3579-
df1.except_distinct(df2).sort(column("a").sort(ascending=True)).collect()[0]
3616+
getattr(df1, method)(df2).sort(column("a").sort(ascending=True)).collect()[0]
35803617
)
3581-
assert result.column(0).to_pylist() == [3]
3582-
assert result.column(1).to_pylist() == [30]
3583-
3584-
3585-
def test_intersect_distinct():
3586-
ctx = SessionContext()
3587-
df1 = ctx.from_pydict({"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]})
3588-
df2 = ctx.from_pydict({"a": [1, 4], "b": [10, 40]})
3589-
result = df1.intersect_distinct(df2).collect()[0]
3590-
assert result.column(0).to_pylist() == [1]
3591-
assert result.column(1).to_pylist() == [10]
3592-
3593-
3594-
def test_union_by_name():
3595-
ctx = SessionContext()
3596-
df1 = ctx.from_pydict({"a": [1], "b": [10]})
3597-
# Different column order
3598-
df2 = ctx.from_pydict({"b": [20], "a": [2]})
3599-
batches = df1.union_by_name(df2).sort(column("a").sort(ascending=True)).collect()
3600-
rows = pa.concat_arrays([b.column(0) for b in batches]).to_pylist()
3601-
assert rows == [1, 2]
3602-
3603-
3604-
def test_union_by_name_distinct():
3605-
ctx = SessionContext()
3606-
df1 = ctx.from_pydict({"a": [1, 1], "b": [10, 10]})
3607-
df2 = ctx.from_pydict({"b": [10], "a": [1]})
3608-
batches = df1.union_by_name_distinct(df2).collect()
3609-
total_rows = sum(b.num_rows for b in batches)
3610-
assert total_rows == 1
3618+
assert result.column(0).to_pylist() == expected_a
3619+
assert result.column(1).to_pylist() == expected_b
36113620

36123621

36133622
def test_distinct_on():
@@ -3627,50 +3636,38 @@ def test_distinct_on():
36273636
assert result.column(1).to_pylist() == [10, 30]
36283637

36293638

3630-
def test_sort_by():
3631-
ctx = SessionContext()
3632-
df = ctx.from_pydict({"a": [3, 1, 2]})
3633-
result = df.sort_by(column("a")).collect()[0]
3634-
# sort_by always sorts ascending with nulls last
3635-
assert result.column(0).to_pylist() == [1, 2, 3]
3636-
3637-
3638-
def test_sort_by_is_always_ascending():
3639-
"""Verify sort_by uses ascending order regardless of input order."""
3639+
@pytest.mark.parametrize(
3640+
"input_values",
3641+
[
3642+
[3, 1, 2],
3643+
[1, 2, 3],
3644+
],
3645+
)
3646+
def test_sort_by(input_values):
3647+
"""sort_by always sorts ascending with nulls last regardless of input order."""
36403648
ctx = SessionContext()
3641-
df = ctx.from_pydict({"a": [1, 2, 3]})
3649+
df = ctx.from_pydict({"a": input_values})
36423650
result = df.sort_by(column("a")).collect()[0]
36433651
assert result.column(0).to_pylist() == [1, 2, 3]
36443652

36453653

3646-
def test_explain_with_format(capsys):
3654+
@pytest.mark.parametrize(
3655+
("fmt", "verbose", "analyze"),
3656+
[
3657+
(None, False, False),
3658+
("TREE", False, False),
3659+
("INDENT", True, True),
3660+
("PGJSON", False, False),
3661+
("GRAPHVIZ", False, False),
3662+
],
3663+
)
3664+
def test_explain_with_format(capsys, fmt, verbose, analyze):
36473665
from datafusion import ExplainFormat
36483666

36493667
ctx = SessionContext()
36503668
df = ctx.from_pydict({"a": [1]})
3651-
3652-
# Default format works
3653-
df.explain()
3654-
captured = capsys.readouterr()
3655-
assert "plan_type" in captured.out
3656-
3657-
# Tree format produces box-drawing characters
3658-
df.explain(format=ExplainFormat.TREE)
3659-
captured = capsys.readouterr()
3660-
assert "\u250c" in captured.out or "plan_type" in captured.out
3661-
3662-
# Verbose + analyze still works with format
3663-
df.explain(verbose=True, analyze=True, format=ExplainFormat.INDENT)
3664-
captured = capsys.readouterr()
3665-
assert "plan_type" in captured.out
3666-
3667-
# PGJSON format produces valid output
3668-
df.explain(format=ExplainFormat.PGJSON)
3669-
captured = capsys.readouterr()
3670-
assert "plan_type" in captured.out
3671-
3672-
# Graphviz format produces DOT output
3673-
df.explain(format=ExplainFormat.GRAPHVIZ)
3669+
explain_fmt = ExplainFormat[fmt] if fmt is not None else None
3670+
df.explain(verbose=verbose, analyze=analyze, format=explain_fmt)
36743671
captured = capsys.readouterr()
36753672
assert "plan_type" in captured.out
36763673

0 commit comments

Comments
 (0)