@@ -668,29 +668,41 @@ def test_array_function_obj_tests(stmt, py_expr):
668668 assert a == b
669669
670670
@pytest.mark.parametrize(
    ("args", "expected"),
    [
        pytest.param(({"x": 1, "y": 2},), [("x", 1), ("y", 2)], id="dict"),
        pytest.param(
            ({"x": literal(1), "y": literal(2)},),
            [("x", 1), ("y", 2)],
            id="dict_with_exprs",
        ),
        pytest.param(("x", 1, "y", 2), [("x", 1), ("y", 2)], id="variadic_pairs"),
        pytest.param(
            (literal("x"), literal(1), literal("y"), literal(2)),
            [("x", 1), ("y", 2)],
            id="variadic_with_exprs",
        ),
    ],
)
def test_make_map(args, expected):
    """Check ``f.make_map`` across its dict and flat key/value calling forms.

    Each case unpacks ``args`` into ``f.make_map`` and compares the first
    value of the selected column, converted with ``as_py()``, against
    ``expected``. Cases cover a plain dict, a dict whose values are
    ``literal`` expressions, flat key/value arguments, and flat arguments
    that are all ``literal`` expressions.
    """
    session = SessionContext()
    # One-element batch; the map expression under test never references "a".
    one_row = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
    frame = session.create_dataframe([[one_row]])

    map_expr = f.make_map(*args).alias("m")
    first_batch = frame.select(map_expr).collect()[0]
    map_column = first_batch.column(0)
    assert map_column[0].as_py() == expected
691703
692704
693- def test_map_from_two_lists ():
705+ def test_make_map_from_two_lists ():
694706 ctx = SessionContext ()
695707 batch = pa .RecordBatch .from_arrays (
696708 [
@@ -711,30 +723,6 @@ def test_map_from_two_lists():
711723 assert result [i ].as_py () == [expected ]
712724
713725
714- def test_map_from_variadic_pairs ():
715- ctx = SessionContext ()
716- batch = pa .RecordBatch .from_arrays ([pa .array ([1 ])], names = ["a" ])
717- df = ctx .create_dataframe ([[batch ]])
718-
719- result = df .select (f .make_map ("x" , 1 , "y" , 2 ).alias ("m" )).collect ()[0 ].column (0 )
720- assert result [0 ].as_py () == [("x" , 1 ), ("y" , 2 )]
721-
722-
723- def test_map_variadic_with_exprs ():
724- ctx = SessionContext ()
725- batch = pa .RecordBatch .from_arrays ([pa .array ([1 ])], names = ["a" ])
726- df = ctx .create_dataframe ([[batch ]])
727-
728- result = (
729- df .select (
730- f .make_map (literal ("x" ), literal (1 ), literal ("y" ), literal (2 )).alias ("m" )
731- )
732- .collect ()[0 ]
733- .column (0 )
734- )
735- assert result [0 ].as_py () == [("x" , 1 ), ("y" , 2 )]
736-
737-
def test_make_map_odd_args_raises():
    """Three positional arguments must raise ``ValueError`` from ``make_map``.

    The error message is required to match ``"make_map expects"``.
    """
    # The trailing "y" has no value to pair with.
    odd_args = ("x", 1, "y")
    with pytest.raises(ValueError, match="make_map expects"):
        f.make_map(*odd_args)
@@ -745,73 +733,41 @@ def test_make_map_mismatched_lengths():
745733 f .make_map (["a" , "b" ], [1 ])
746734
747735
@pytest.mark.parametrize(
    ("func", "expected"),
    [
        pytest.param(f.map_keys, ["x", "y"], id="map_keys"),
        pytest.param(f.map_values, [1, 2], id="map_values"),
        pytest.param(
            lambda map_expr: f.map_extract(map_expr, literal("x")),
            [1],
            id="map_extract",
        ),
        pytest.param(
            lambda map_expr: f.map_extract(map_expr, literal("z")),
            [None],
            id="map_extract_missing_key",
        ),
        pytest.param(
            f.map_entries,
            [{"key": "x", "value": 1}, {"key": "y", "value": 2}],
            id="map_entries",
        ),
        pytest.param(
            lambda map_expr: f.element_at(map_expr, literal("y")),
            [2],
            id="element_at",
        ),
    ],
)
def test_map_functions(func, expected):
    """Apply each map function to the same ``{"x": 1, "y": 2}`` map.

    ``func`` is a one-argument callable that takes the map expression and
    returns the expression under test. The first value of the selected
    column, converted with ``as_py()``, must equal ``expected``.
    """
    session = SessionContext()
    # One-element batch; the expressions under test never reference "a".
    one_row = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
    frame = session.create_dataframe([[one_row]])

    source_map = f.make_map({"x": 1, "y": 2})
    projected = frame.select(func(source_map).alias("out"))
    out_column = projected.collect()[0].column(0)
    assert out_column[0].as_py() == expected
815771
816772
817773@pytest .mark .parametrize (
0 commit comments