Skip to content

Commit 4aec22d

Browse files
committed
Expose lit_with_metadata and add unit test
1 parent 8a87136 commit 4aec22d

File tree

4 files changed

+99
-2
lines changed

4 files changed

+99
-2
lines changed

python/datafusion/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
See https://datafusion.apache.org/python for more information.
2222
"""
2323

24+
from __future__ import annotations
25+
26+
from typing import Any
27+
2428
try:
2529
import importlib.metadata as importlib_metadata
2630
except ImportError:
@@ -130,3 +134,18 @@ def str_lit(value):
130134
def lit(value) -> Expr:
131135
"""Create a literal expression."""
132136
return Expr.literal(value)
137+
138+
139+
def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
140+
"""Creates a new expression representing a scalar value with metadata.
141+
142+
Args:
143+
value: A valid PyArrow scalar value or easily castable to one.
144+
metadata: Metadata to attach to the expression.
145+
"""
146+
return Expr.literal_with_metadata(value, metadata)
147+
148+
149+
def lit_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
150+
"""Alias for literal_with_metadata."""
151+
return literal_with_metadata(value, metadata)

python/datafusion/expr.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,20 @@ def literal(value: Any) -> Expr:
435435
value = pa.scalar(value)
436436
return Expr(expr_internal.RawExpr.literal(value))
437437

438+
@staticmethod
439+
def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
440+
"""Creates a new expression representing a scalar value with metadata.
441+
442+
Args:
443+
value: A valid PyArrow scalar value or easily castable to one.
444+
metadata: Metadata to attach to the expression.
445+
"""
446+
if isinstance(value, str):
447+
value = pa.scalar(value, type=pa.string_view())
448+
value = value if isinstance(value, pa.Scalar) else pa.scalar(value)
449+
450+
return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata))
451+
438452
@staticmethod
439453
def string_literal(value: str) -> Expr:
440454
"""Creates a new expression representing a UTF8 literal value.

python/tests/test_expr.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,14 @@
1919

2020
import pyarrow as pa
2121
import pytest
22-
from datafusion import SessionContext, col, functions, lit
22+
from datafusion import (
23+
SessionContext,
24+
col,
25+
functions,
26+
lit,
27+
lit_with_metadata,
28+
literal_with_metadata,
29+
)
2330
from datafusion.expr import (
2431
Aggregate,
2532
AggregateFunction,
@@ -824,3 +831,52 @@ def test_expr_functions(ctx, function, expected_result):
824831

825832
assert len(result) == 1
826833
assert result[0].column(0).equals(expected_result)
834+
835+
836+
def test_literal_metadata(ctx):
837+
result = (
838+
ctx.from_pydict({"a": [1]})
839+
.select(
840+
lit(1).alias("no_metadata"),
841+
lit_with_metadata(2, {"key1": "value1"}).alias("lit_with_metadata_fn"),
842+
literal_with_metadata(3, {"key2": "value2"}).alias(
843+
"literal_with_metadata_fn"
844+
),
845+
)
846+
.collect()
847+
)
848+
849+
expected_schema = pa.schema(
850+
[
851+
pa.field("no_metadata", pa.int64(), nullable=False),
852+
pa.field(
853+
"lit_with_metadata_fn",
854+
pa.int64(),
855+
nullable=False,
856+
metadata={"key1": "value1"},
857+
),
858+
pa.field(
859+
"literal_with_metadata_fn",
860+
pa.int64(),
861+
nullable=False,
862+
metadata={"key2": "value2"},
863+
),
864+
]
865+
)
866+
867+
expected = pa.RecordBatch.from_pydict(
868+
{
869+
"no_metadata": pa.array([1]),
870+
"lit_with_metadata_fn": pa.array([2]),
871+
"literal_with_metadata_fn": pa.array([3]),
872+
},
873+
schema=expected_schema,
874+
)
875+
876+
assert result[0] == expected
877+
878+
# Testing result[0].schema == expected_schema does not check each key/value pair
879+
# so we want to explicitly test these
880+
for expected_field in expected_schema:
881+
actual_field = result[0].schema.field(expected_field.name)
882+
assert expected_field.metadata == actual_field.metadata

src/expr.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use datafusion::logical_expr::expr::AggregateFunctionParams;
1919
use datafusion::logical_expr::utils::exprlist_to_fields;
2020
use datafusion::logical_expr::{
21-
ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition,
21+
lit_with_metadata, ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition,
2222
};
2323
use pyo3::IntoPyObjectExt;
2424
use pyo3::{basic::CompareOp, prelude::*};
@@ -282,6 +282,14 @@ impl PyExpr {
282282
lit(value.0).into()
283283
}
284284

285+
#[staticmethod]
286+
pub fn literal_with_metadata(
287+
value: PyScalarValue,
288+
metadata: HashMap<String, String>,
289+
) -> PyExpr {
290+
lit_with_metadata(value.0, metadata).into()
291+
}
292+
285293
#[staticmethod]
286294
pub fn column(value: &str) -> PyExpr {
287295
col(value).into()

0 commit comments

Comments
 (0)