Skip to content

Commit 461acca

Browse files
timsaucerclaude
andcommitted
Consolidate map function tests into parametrized groups
Reduce boilerplate by combining make_map construction tests and map accessor function tests into two @pytest.mark.parametrize groups. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent dbe4d2e commit 461acca

File tree

1 file changed

+59
-103
lines changed

1 file changed

+59
-103
lines changed

python/tests/test_functions.py

Lines changed: 59 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -668,29 +668,41 @@ def test_array_function_obj_tests(stmt, py_expr):
668668
assert a == b
669669

670670

671-
def test_map_from_dict():
672-
ctx = SessionContext()
673-
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
674-
df = ctx.create_dataframe([[batch]])
675-
676-
result = df.select(f.make_map({"x": 1, "y": 2}).alias("m")).collect()[0].column(0)
677-
assert result[0].as_py() == [("x", 1), ("y", 2)]
678-
679-
680-
def test_map_from_dict_with_expr_values():
671+
@pytest.mark.parametrize(
672+
("args", "expected"),
673+
[
674+
pytest.param(
675+
({"x": 1, "y": 2},),
676+
[("x", 1), ("y", 2)],
677+
id="dict",
678+
),
679+
pytest.param(
680+
({"x": literal(1), "y": literal(2)},),
681+
[("x", 1), ("y", 2)],
682+
id="dict_with_exprs",
683+
),
684+
pytest.param(
685+
("x", 1, "y", 2),
686+
[("x", 1), ("y", 2)],
687+
id="variadic_pairs",
688+
),
689+
pytest.param(
690+
(literal("x"), literal(1), literal("y"), literal(2)),
691+
[("x", 1), ("y", 2)],
692+
id="variadic_with_exprs",
693+
),
694+
],
695+
)
696+
def test_make_map(args, expected):
681697
ctx = SessionContext()
682698
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
683699
df = ctx.create_dataframe([[batch]])
684700

685-
result = (
686-
df.select(f.make_map({"x": literal(1), "y": literal(2)}).alias("m"))
687-
.collect()[0]
688-
.column(0)
689-
)
690-
assert result[0].as_py() == [("x", 1), ("y", 2)]
701+
result = df.select(f.make_map(*args).alias("m")).collect()[0].column(0)
702+
assert result[0].as_py() == expected
691703

692704

693-
def test_map_from_two_lists():
705+
def test_make_map_from_two_lists():
694706
ctx = SessionContext()
695707
batch = pa.RecordBatch.from_arrays(
696708
[
@@ -711,30 +723,6 @@ def test_map_from_two_lists():
711723
assert result[i].as_py() == [expected]
712724

713725

714-
def test_map_from_variadic_pairs():
715-
ctx = SessionContext()
716-
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
717-
df = ctx.create_dataframe([[batch]])
718-
719-
result = df.select(f.make_map("x", 1, "y", 2).alias("m")).collect()[0].column(0)
720-
assert result[0].as_py() == [("x", 1), ("y", 2)]
721-
722-
723-
def test_map_variadic_with_exprs():
724-
ctx = SessionContext()
725-
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
726-
df = ctx.create_dataframe([[batch]])
727-
728-
result = (
729-
df.select(
730-
f.make_map(literal("x"), literal(1), literal("y"), literal(2)).alias("m")
731-
)
732-
.collect()[0]
733-
.column(0)
734-
)
735-
assert result[0].as_py() == [("x", 1), ("y", 2)]
736-
737-
738726
def test_make_map_odd_args_raises():
739727
with pytest.raises(ValueError, match="make_map expects"):
740728
f.make_map("x", 1, "y")
@@ -745,73 +733,41 @@ def test_make_map_mismatched_lengths():
745733
f.make_map(["a", "b"], [1])
746734

747735

748-
def test_map_keys():
749-
ctx = SessionContext()
750-
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
751-
df = ctx.create_dataframe([[batch]])
752-
753-
m = f.make_map({"x": 1, "y": 2})
754-
result = df.select(f.map_keys(m).alias("keys")).collect()[0].column(0)
755-
assert result[0].as_py() == ["x", "y"]
756-
757-
758-
def test_map_values():
759-
ctx = SessionContext()
760-
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
761-
df = ctx.create_dataframe([[batch]])
762-
763-
m = f.make_map({"x": 1, "y": 2})
764-
result = df.select(f.map_values(m).alias("vals")).collect()[0].column(0)
765-
assert result[0].as_py() == [1, 2]
766-
767-
768-
def test_map_extract():
769-
ctx = SessionContext()
770-
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
771-
df = ctx.create_dataframe([[batch]])
772-
773-
m = f.make_map({"x": 1, "y": 2})
774-
result = (
775-
df.select(f.map_extract(m, literal("x")).alias("val")).collect()[0].column(0)
776-
)
777-
assert result[0].as_py() == [1]
778-
779-
780-
def test_map_extract_missing_key():
781-
ctx = SessionContext()
782-
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
783-
df = ctx.create_dataframe([[batch]])
784-
785-
m = f.make_map({"x": 1})
786-
result = (
787-
df.select(f.map_extract(m, literal("z")).alias("val")).collect()[0].column(0)
788-
)
789-
assert result[0].as_py() == [None]
790-
791-
792-
def test_map_entries():
736+
@pytest.mark.parametrize(
737+
("func", "expected"),
738+
[
739+
pytest.param(f.map_keys, ["x", "y"], id="map_keys"),
740+
pytest.param(f.map_values, [1, 2], id="map_values"),
741+
pytest.param(
742+
lambda m: f.map_extract(m, literal("x")),
743+
[1],
744+
id="map_extract",
745+
),
746+
pytest.param(
747+
lambda m: f.map_extract(m, literal("z")),
748+
[None],
749+
id="map_extract_missing_key",
750+
),
751+
pytest.param(
752+
f.map_entries,
753+
[{"key": "x", "value": 1}, {"key": "y", "value": 2}],
754+
id="map_entries",
755+
),
756+
pytest.param(
757+
lambda m: f.element_at(m, literal("y")),
758+
[2],
759+
id="element_at",
760+
),
761+
],
762+
)
763+
def test_map_functions(func, expected):
793764
ctx = SessionContext()
794765
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
795766
df = ctx.create_dataframe([[batch]])
796767

797768
m = f.make_map({"x": 1, "y": 2})
798-
result = df.select(f.map_entries(m).alias("entries")).collect()[0].column(0)
799-
assert result[0].as_py() == [
800-
{"key": "x", "value": 1},
801-
{"key": "y", "value": 2},
802-
]
803-
804-
805-
def test_element_at():
806-
ctx = SessionContext()
807-
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
808-
df = ctx.create_dataframe([[batch]])
809-
810-
m = f.make_map({"a": 10, "b": 20})
811-
result = (
812-
df.select(f.element_at(m, literal("b")).alias("val")).collect()[0].column(0)
813-
)
814-
assert result[0].as_py() == [20]
769+
result = df.select(func(m).alias("out")).collect()[0].column(0)
770+
assert result[0].as_py() == expected
815771

816772

817773
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)