Skip to content

Commit 8a35cae

Browse files
timsaucerclaudenuno-faria
authored
Add missing map functions (#1461)
* Add map functions (make_map, map_keys, map_values, map_extract, map_entries, element_at) Closes #1448 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Add unit tests for map functions Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Remove redundant pyo3 element_at function element_at is already a Python-only alias for map_extract, so the Rust binding is unnecessary. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Change make_map to accept a Python dictionary make_map now takes a dict for the common case and also supports separate keys/values lists for column expressions. Non-Expr keys and values are automatically converted to literals. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Make map the primary function with make_map as alias map() now supports three calling conventions matching upstream: - map({"a": 1, "b": 2}) — from a Python dictionary - map([keys], [values]) — two lists that get zipped - map(k1, v1, k2, v2, ...) — variadic key-value pairs Non-Expr keys and values are automatically converted to literals. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Improve map function docstrings - Add examples for all three map() calling conventions - Use clearer descriptions instead of jargon (no "zipped" or "variadic") - Break map_keys/map_values/map_extract/map_entries examples into two steps: create the map column first, then call the function Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Remove map() in favor of make_map(), fix docstrings, add validation - Remove map() function that shadowed Python builtin; make_map() is now the sole entry point for creating map expressions - Fix map_extract/element_at docstrings: missing keys return [None], not an empty list (matches actual upstream behavior) - Add length validation for the two-list calling convention - Update all tests and docstring examples accordingly Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Consolidate map function tests into parametrized groups Reduce boilerplate by combining make_map construction tests and map accessor function tests into two @pytest.mark.parametrize groups. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Docstring update Co-authored-by: Nuno Faria <nunofpfaria@gmail.com> * Docstring update Co-authored-by: Nuno Faria <nunofpfaria@gmail.com> * Simplify test for readability Co-authored-by: Nuno Faria <nunofpfaria@gmail.com> * Simplify test for readability Co-authored-by: Nuno Faria <nunofpfaria@gmail.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Nuno Faria <nunofpfaria@gmail.com>
1 parent 16feeb1 commit 8a35cae

File tree

3 files changed

+278
-0
lines changed

3 files changed

+278
-0
lines changed

crates/core/src/functions.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ fn array_cat(exprs: Vec<PyExpr>) -> PyExpr {
9393
array_concat(exprs)
9494
}
9595

96+
#[pyfunction]
97+
fn make_map(keys: Vec<PyExpr>, values: Vec<PyExpr>) -> PyExpr {
98+
let keys = keys.into_iter().map(|x| x.into()).collect();
99+
let values = values.into_iter().map(|x| x.into()).collect();
100+
datafusion::functions_nested::map::map(keys, values).into()
101+
}
102+
96103
#[pyfunction]
97104
#[pyo3(signature = (array, element, index=None))]
98105
fn array_position(array: PyExpr, element: PyExpr, index: Option<i64>) -> PyExpr {
@@ -678,6 +685,12 @@ array_fn!(cardinality, array);
678685
array_fn!(flatten, array);
679686
array_fn!(range, start stop step);
680687

688+
// Map Functions
689+
array_fn!(map_keys, map);
690+
array_fn!(map_values, map);
691+
array_fn!(map_extract, map key);
692+
array_fn!(map_entries, map);
693+
681694
aggregate_function!(array_agg);
682695
aggregate_function!(max);
683696
aggregate_function!(min);
@@ -1142,6 +1155,13 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
11421155
m.add_wrapped(wrap_pyfunction!(flatten))?;
11431156
m.add_wrapped(wrap_pyfunction!(cardinality))?;
11441157

1158+
// Map Functions
1159+
m.add_wrapped(wrap_pyfunction!(make_map))?;
1160+
m.add_wrapped(wrap_pyfunction!(map_keys))?;
1161+
m.add_wrapped(wrap_pyfunction!(map_values))?;
1162+
m.add_wrapped(wrap_pyfunction!(map_extract))?;
1163+
m.add_wrapped(wrap_pyfunction!(map_entries))?;
1164+
11451165
// Window Functions
11461166
m.add_wrapped(wrap_pyfunction!(lead))?;
11471167
m.add_wrapped(wrap_pyfunction!(lag))?;

python/datafusion/functions.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@
140140
"degrees",
141141
"dense_rank",
142142
"digest",
143+
"element_at",
143144
"empty",
144145
"encode",
145146
"ends_with",
@@ -206,7 +207,12 @@
206207
"make_array",
207208
"make_date",
208209
"make_list",
210+
"make_map",
209211
"make_time",
212+
"map_entries",
213+
"map_extract",
214+
"map_keys",
215+
"map_values",
210216
"max",
211217
"md5",
212218
"mean",
@@ -3458,6 +3464,158 @@ def empty(array: Expr) -> Expr:
34583464
return array_empty(array)
34593465

34603466

3467+
# map functions
3468+
3469+
3470+
def make_map(*args: Any) -> Expr:
3471+
"""Returns a map expression.
3472+
3473+
Supports three calling conventions:
3474+
3475+
- ``make_map({"a": 1, "b": 2})`` — from a Python dictionary.
3476+
- ``make_map([keys], [values])`` — from a list of keys and a list of
3477+
their associated values. Both lists must be the same length.
3478+
- ``make_map(k1, v1, k2, v2, ...)`` — from alternating keys and their
3479+
associated values.
3480+
3481+
Keys and values that are not already :py:class:`~datafusion.expr.Expr`
3482+
are automatically converted to literal expressions.
3483+
3484+
Examples:
3485+
From a dictionary:
3486+
3487+
>>> ctx = dfn.SessionContext()
3488+
>>> df = ctx.from_pydict({"a": [1]})
3489+
>>> result = df.select(
3490+
... dfn.functions.make_map({"a": 1, "b": 2}).alias("m"))
3491+
>>> result.collect_column("m")[0].as_py()
3492+
[('a', 1), ('b', 2)]
3493+
3494+
From two lists:
3495+
3496+
>>> df = ctx.from_pydict({"key": ["x", "y"], "val": [10, 20]})
3497+
>>> df = df.select(
3498+
... dfn.functions.make_map(
3499+
... [dfn.col("key")], [dfn.col("val")]
3500+
... ).alias("m"))
3501+
>>> df.collect_column("m")[0].as_py()
3502+
[('x', 10)]
3503+
3504+
From alternating keys and values:
3505+
3506+
>>> df = ctx.from_pydict({"a": [1]})
3507+
>>> result = df.select(
3508+
... dfn.functions.make_map("x", 1, "y", 2).alias("m"))
3509+
>>> result.collect_column("m")[0].as_py()
3510+
[('x', 1), ('y', 2)]
3511+
"""
3512+
if len(args) == 1 and isinstance(args[0], dict):
3513+
key_list = list(args[0].keys())
3514+
value_list = list(args[0].values())
3515+
elif (
3516+
len(args) == 2 # noqa: PLR2004
3517+
and isinstance(args[0], list)
3518+
and isinstance(args[1], list)
3519+
):
3520+
if len(args[0]) != len(args[1]):
3521+
msg = "make_map requires key and value lists to be the same length"
3522+
raise ValueError(msg)
3523+
key_list = args[0]
3524+
value_list = args[1]
3525+
elif len(args) >= 2 and len(args) % 2 == 0: # noqa: PLR2004
3526+
key_list = list(args[0::2])
3527+
value_list = list(args[1::2])
3528+
else:
3529+
msg = (
3530+
"make_map expects a dict, two lists, or an even number of "
3531+
"key-value arguments"
3532+
)
3533+
raise ValueError(msg)
3534+
3535+
key_exprs = [k if isinstance(k, Expr) else Expr.literal(k) for k in key_list]
3536+
val_exprs = [v if isinstance(v, Expr) else Expr.literal(v) for v in value_list]
3537+
return Expr(f.make_map([k.expr for k in key_exprs], [v.expr for v in val_exprs]))
3538+
3539+
3540+
def map_keys(map: Expr) -> Expr:
3541+
"""Returns a list of all keys in the map.
3542+
3543+
Examples:
3544+
>>> ctx = dfn.SessionContext()
3545+
>>> df = ctx.from_pydict({"a": [1]})
3546+
>>> df = df.select(
3547+
... dfn.functions.make_map({"x": 1, "y": 2}).alias("m"))
3548+
>>> result = df.select(
3549+
... dfn.functions.map_keys(dfn.col("m")).alias("keys"))
3550+
>>> result.collect_column("keys")[0].as_py()
3551+
['x', 'y']
3552+
"""
3553+
return Expr(f.map_keys(map.expr))
3554+
3555+
3556+
def map_values(map: Expr) -> Expr:
3557+
"""Returns a list of all values in the map.
3558+
3559+
Examples:
3560+
>>> ctx = dfn.SessionContext()
3561+
>>> df = ctx.from_pydict({"a": [1]})
3562+
>>> df = df.select(
3563+
... dfn.functions.make_map({"x": 1, "y": 2}).alias("m"))
3564+
>>> result = df.select(
3565+
... dfn.functions.map_values(dfn.col("m")).alias("vals"))
3566+
>>> result.collect_column("vals")[0].as_py()
3567+
[1, 2]
3568+
"""
3569+
return Expr(f.map_values(map.expr))
3570+
3571+
3572+
def map_extract(map: Expr, key: Expr) -> Expr:
3573+
"""Returns the value for a given key in the map.
3574+
3575+
Returns ``[None]`` if the key is absent.
3576+
3577+
Examples:
3578+
>>> ctx = dfn.SessionContext()
3579+
>>> df = ctx.from_pydict({"a": [1]})
3580+
>>> df = df.select(
3581+
... dfn.functions.make_map({"x": 1, "y": 2}).alias("m"))
3582+
>>> result = df.select(
3583+
... dfn.functions.map_extract(
3584+
... dfn.col("m"), dfn.lit("x")
3585+
... ).alias("val"))
3586+
>>> result.collect_column("val")[0].as_py()
3587+
[1]
3588+
"""
3589+
return Expr(f.map_extract(map.expr, key.expr))
3590+
3591+
3592+
def map_entries(map: Expr) -> Expr:
3593+
"""Returns a list of all entries (key-value struct pairs) in the map.
3594+
3595+
Examples:
3596+
>>> ctx = dfn.SessionContext()
3597+
>>> df = ctx.from_pydict({"a": [1]})
3598+
>>> df = df.select(
3599+
... dfn.functions.make_map({"x": 1, "y": 2}).alias("m"))
3600+
>>> result = df.select(
3601+
... dfn.functions.map_entries(dfn.col("m")).alias("entries"))
3602+
>>> result.collect_column("entries")[0].as_py()
3603+
[{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}]
3604+
"""
3605+
return Expr(f.map_entries(map.expr))
3606+
3607+
3608+
def element_at(map: Expr, key: Expr) -> Expr:
3609+
"""Returns the value for a given key in the map.
3610+
3611+
Returns ``[None]`` if the key is absent.
3612+
3613+
See Also:
3614+
This is an alias for :py:func:`map_extract`.
3615+
"""
3616+
return map_extract(map, key)
3617+
3618+
34613619
# aggregate functions
34623620
def approx_distinct(
34633621
expression: Expr,

python/tests/test_functions.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,106 @@ def test_array_function_obj_tests(stmt, py_expr):
668668
assert a == b
669669

670670

671+
@pytest.mark.parametrize(
672+
("args", "expected"),
673+
[
674+
pytest.param(
675+
({"x": 1, "y": 2},),
676+
[("x", 1), ("y", 2)],
677+
id="dict",
678+
),
679+
pytest.param(
680+
({"x": literal(1), "y": literal(2)},),
681+
[("x", 1), ("y", 2)],
682+
id="dict_with_exprs",
683+
),
684+
pytest.param(
685+
("x", 1, "y", 2),
686+
[("x", 1), ("y", 2)],
687+
id="variadic_pairs",
688+
),
689+
pytest.param(
690+
(literal("x"), literal(1), literal("y"), literal(2)),
691+
[("x", 1), ("y", 2)],
692+
id="variadic_with_exprs",
693+
),
694+
],
695+
)
696+
def test_make_map(args, expected):
697+
ctx = SessionContext()
698+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
699+
df = ctx.create_dataframe([[batch]])
700+
701+
result = df.select(f.make_map(*args).alias("m")).collect()[0].column(0)
702+
assert result[0].as_py() == expected
703+
704+
705+
def test_make_map_from_two_lists():
706+
ctx = SessionContext()
707+
batch = pa.RecordBatch.from_arrays(
708+
[
709+
pa.array(["k1", "k2", "k3"]),
710+
pa.array([10, 20, 30]),
711+
],
712+
names=["keys", "vals"],
713+
)
714+
df = ctx.create_dataframe([[batch]])
715+
716+
m = f.make_map([column("keys")], [column("vals")])
717+
result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0)
718+
assert result.to_pylist() == [["k1"], ["k2"], ["k3"]]
719+
720+
result = df.select(f.map_values(m).alias("v")).collect()[0].column(0)
721+
assert result.to_pylist() == [[10], [20], [30]]
722+
723+
724+
def test_make_map_odd_args_raises():
725+
with pytest.raises(ValueError, match="make_map expects"):
726+
f.make_map("x", 1, "y")
727+
728+
729+
def test_make_map_mismatched_lengths():
730+
with pytest.raises(ValueError, match="same length"):
731+
f.make_map(["a", "b"], [1])
732+
733+
734+
@pytest.mark.parametrize(
735+
("func", "expected"),
736+
[
737+
pytest.param(f.map_keys, ["x", "y"], id="map_keys"),
738+
pytest.param(f.map_values, [1, 2], id="map_values"),
739+
pytest.param(
740+
lambda m: f.map_extract(m, literal("x")),
741+
[1],
742+
id="map_extract",
743+
),
744+
pytest.param(
745+
lambda m: f.map_extract(m, literal("z")),
746+
[None],
747+
id="map_extract_missing_key",
748+
),
749+
pytest.param(
750+
f.map_entries,
751+
[{"key": "x", "value": 1}, {"key": "y", "value": 2}],
752+
id="map_entries",
753+
),
754+
pytest.param(
755+
lambda m: f.element_at(m, literal("y")),
756+
[2],
757+
id="element_at",
758+
),
759+
],
760+
)
761+
def test_map_functions(func, expected):
762+
ctx = SessionContext()
763+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
764+
df = ctx.create_dataframe([[batch]])
765+
766+
m = f.make_map({"x": 1, "y": 2})
767+
result = df.select(func(m).alias("out")).collect()[0].column(0)
768+
assert result[0].as_py() == expected
769+
770+
671771
@pytest.mark.parametrize(
672772
("function", "expected_result"),
673773
[

0 commit comments

Comments
 (0)