@@ -3571,43 +3571,52 @@ def test_read_parquet_file_sort_order(tmp_path, file_sort_order):
35713571 assert df .collect ()[0 ].column (0 ).to_pylist () == [1 , 2 ]
35723572
35733573
@pytest.mark.parametrize(
    ("df1_data", "df2_data", "method", "expected_a", "expected_b"),
    [
        pytest.param(
            {"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]},
            {"a": [1, 2], "b": [10, 20]},
            "except_distinct",
            [3],
            [30],
            id="except_distinct: removes matching rows and deduplicates",
        ),
        pytest.param(
            {"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]},
            {"a": [1, 4], "b": [10, 40]},
            "intersect_distinct",
            [1],
            [10],
            id="intersect_distinct: keeps common rows and deduplicates",
        ),
        pytest.param(
            {"a": [1], "b": [10]},
            {"b": [20], "a": [2]},  # reversed column order tests matching by name
            "union_by_name",
            [1, 2],
            [10, 20],
            id="union_by_name: matches columns by name not position",
        ),
        pytest.param(
            {"a": [1, 1], "b": [10, 10]},
            {"b": [10], "a": [1]},  # reversed column order with duplicates
            "union_by_name_distinct",
            [1],
            [10],
            id="union_by_name_distinct: matches by name and deduplicates",
        ),
    ],
)
def test_set_operations_distinct(df1_data, df2_data, method, expected_a, expected_b):
    """Distinct set operations deduplicate and match columns correctly.

    The result is collected as a list of record batches and concatenated
    before asserting: union operators in particular are not guaranteed to
    emit a single batch, so ``collect()[0]`` alone could drop rows.
    """
    ctx = SessionContext()
    df1 = ctx.from_pydict(df1_data)
    df2 = ctx.from_pydict(df2_data)
    batches = (
        getattr(df1, method)(df2)
        .sort(column("a").sort(ascending=True))
        .collect()
    )
    # Concatenate across all batches (mirrors the pre-refactor union tests,
    # which deliberately avoided assuming a single-batch result).
    col_a = pa.concat_arrays([b.column(0) for b in batches]).to_pylist()
    col_b = pa.concat_arrays([b.column(1) for b in batches]).to_pylist()
    assert col_a == expected_a
    assert col_b == expected_b
36113620
36123621
36133622def test_distinct_on ():
@@ -3627,50 +3636,38 @@ def test_distinct_on():
36273636 assert result .column (1 ).to_pylist () == [10 , 30 ]
36283637
36293638
@pytest.mark.parametrize(
    ("input_values", "expected"),
    [
        ([3, 1, 2], [1, 2, 3]),
        ([1, 2, 3], [1, 2, 3]),  # already-sorted input stays sorted
        # A null in the input verifies the documented "nulls last" placement,
        # which no previous case actually exercised.
        ([3, None, 2], [2, 3, None]),
    ],
)
def test_sort_by(input_values, expected):
    """sort_by always sorts ascending with nulls last, regardless of input order."""
    ctx = SessionContext()
    df = ctx.from_pydict({"a": input_values})
    result = df.sort_by(column("a")).collect()[0]
    assert result.column(0).to_pylist() == expected
36453653
3646- def test_explain_with_format (capsys ):
3654+ @pytest .mark .parametrize (
3655+ ("fmt" , "verbose" , "analyze" ),
3656+ [
3657+ (None , False , False ),
3658+ ("TREE" , False , False ),
3659+ ("INDENT" , True , True ),
3660+ ("PGJSON" , False , False ),
3661+ ("GRAPHVIZ" , False , False ),
3662+ ],
3663+ )
3664+ def test_explain_with_format (capsys , fmt , verbose , analyze ):
36473665 from datafusion import ExplainFormat
36483666
36493667 ctx = SessionContext ()
36503668 df = ctx .from_pydict ({"a" : [1 ]})
3651-
3652- # Default format works
3653- df .explain ()
3654- captured = capsys .readouterr ()
3655- assert "plan_type" in captured .out
3656-
3657- # Tree format produces box-drawing characters
3658- df .explain (format = ExplainFormat .TREE )
3659- captured = capsys .readouterr ()
3660- assert "\u250c " in captured .out or "plan_type" in captured .out
3661-
3662- # Verbose + analyze still works with format
3663- df .explain (verbose = True , analyze = True , format = ExplainFormat .INDENT )
3664- captured = capsys .readouterr ()
3665- assert "plan_type" in captured .out
3666-
3667- # PGJSON format produces valid output
3668- df .explain (format = ExplainFormat .PGJSON )
3669- captured = capsys .readouterr ()
3670- assert "plan_type" in captured .out
3671-
3672- # Graphviz format produces DOT output
3673- df .explain (format = ExplainFormat .GRAPHVIZ )
3669+ explain_fmt = ExplainFormat [fmt ] if fmt is not None else None
3670+ df .explain (verbose = verbose , analyze = analyze , format = explain_fmt )
36743671 captured = capsys .readouterr ()
36753672 assert "plan_type" in captured .out
36763673
0 commit comments