Skip to content

Commit 796de63

Browse files
committed
feat: make LiteralPredicate serializable via internal IcebergBaseModel
1 parent 5ee5eea commit 796de63

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

pyiceberg/expressions/__init__.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,16 @@
3232
Union,
3333
)
3434

35+
from pydantic import Field
36+
3537
from pyiceberg.expressions.literals import (
3638
AboveMax,
3739
BelowMin,
3840
Literal,
3941
literal,
4042
)
4143
from pyiceberg.schema import Accessor, Schema
42-
from pyiceberg.typedef import IcebergRootModel, L, StructProtocol
44+
from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel, L, StructProtocol
4345
from pyiceberg.types import DoubleType, FloatType, NestedField
4446
from pyiceberg.utils.singleton import Singleton
4547

@@ -732,6 +734,39 @@ def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal
732734
super().__init__(term)
733735
self.literal = _to_literal(literal) # pylint: disable=W0621
734736

737+
# ---- JSON (Pydantic) serialization helpers ----
738+
739+
class _LiteralPredicateModel(IcebergBaseModel):
740+
type: str = Field(alias="type")
741+
term: str
742+
value: Any
743+
744+
def _json_op(self) -> str:
745+
mapping = {
746+
EqualTo: "eq",
747+
NotEqualTo: "not-eq",
748+
LessThan: "lt",
749+
LessThanOrEqual: "lt-eq",
750+
GreaterThan: "gt",
751+
GreaterThanOrEqual: "gt-eq",
752+
StartsWith: "starts-with",
753+
NotStartsWith: "not-starts-with",
754+
}
755+
for cls, op in mapping.items():
756+
if isinstance(self, cls):
757+
return op
758+
raise ValueError(f"Unknown LiteralPredicate: {type(self).__name__}")
759+
760+
def model_dump(self, **kwargs: Any) -> dict:
761+
term_name = getattr(self.term, "name", str(self.term))
762+
return self._LiteralPredicateModel(type=self._json_op(), term=term_name, value=self.literal.value).model_dump(**kwargs)
763+
764+
def model_dump_json(self, **kwargs: Any) -> str:
765+
term_name = getattr(self.term, "name", str(self.term))
766+
return self._LiteralPredicateModel(type=self._json_op(), term=term_name, value=self.literal.value).model_dump_json(
767+
**kwargs
768+
)
769+
735770
def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate[L]:
736771
bound_term = self.term.bind(schema, case_sensitive)
737772
lit = self.literal.to(bound_term.ref().field.field_type)

tests/expressions/test_expressions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,10 @@
5555
NotIn,
5656
NotNaN,
5757
NotNull,
58+
NotStartsWith,
5859
Or,
5960
Reference,
61+
StartsWith,
6062
UnboundPredicate,
6163
)
6264
from pyiceberg.expressions.literals import Literal, literal
@@ -921,6 +923,7 @@ def test_bound_less_than_or_equal(term: BoundReference[Any]) -> None:
921923

922924
def test_equal_to() -> None:
923925
equal_to = EqualTo(Reference("a"), literal("a"))
926+
assert equal_to.model_dump_json() == '{"type":"eq","term":"a","value":"a"}'
924927
assert str(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
925928
assert repr(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
926929
assert equal_to == eval(repr(equal_to))
@@ -929,6 +932,7 @@ def test_equal_to() -> None:
929932

930933
def test_not_equal_to() -> None:
931934
not_equal_to = NotEqualTo(Reference("a"), literal("a"))
935+
assert not_equal_to.model_dump_json() == '{"type":"not-eq","term":"a","value":"a"}'
932936
assert str(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
933937
assert repr(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
934938
assert not_equal_to == eval(repr(not_equal_to))
@@ -937,6 +941,7 @@ def test_not_equal_to() -> None:
937941

938942
def test_greater_than_or_equal_to() -> None:
939943
greater_than_or_equal_to = GreaterThanOrEqual(Reference("a"), literal("a"))
944+
assert greater_than_or_equal_to.model_dump_json() == '{"type":"gt-eq","term":"a","value":"a"}'
940945
assert str(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
941946
assert repr(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
942947
assert greater_than_or_equal_to == eval(repr(greater_than_or_equal_to))
@@ -945,6 +950,7 @@ def test_greater_than_or_equal_to() -> None:
945950

946951
def test_greater_than() -> None:
947952
greater_than = GreaterThan(Reference("a"), literal("a"))
953+
assert greater_than.model_dump_json() == '{"type":"gt","term":"a","value":"a"}'
948954
assert str(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
949955
assert repr(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
950956
assert greater_than == eval(repr(greater_than))
@@ -953,14 +959,26 @@ def test_greater_than() -> None:
953959

954960
def test_less_than() -> None:
955961
less_than = LessThan(Reference("a"), literal("a"))
962+
assert less_than.model_dump_json() == '{"type":"lt","term":"a","value":"a"}'
956963
assert str(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
957964
assert repr(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
958965
assert less_than == eval(repr(less_than))
959966
assert less_than == pickle.loads(pickle.dumps(less_than))
960967

961968

969+
def test_starts_with() -> None:
970+
starts_with = StartsWith(Reference("a"), literal("a"))
971+
assert starts_with.model_dump_json() == '{"type":"starts-with","term":"a","value":"a"}'
972+
973+
974+
def test_not_starts_with() -> None:
975+
not_starts_with = NotStartsWith(Reference("a"), literal("a"))
976+
assert not_starts_with.model_dump_json() == '{"type":"not-starts-with","term":"a","value":"a"}'
977+
978+
962979
def test_less_than_or_equal() -> None:
963980
less_than_or_equal = LessThanOrEqual(Reference("a"), literal("a"))
981+
assert less_than_or_equal.model_dump_json() == '{"type":"lt-eq","term":"a","value":"a"}'
964982
assert str(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
965983
assert repr(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
966984
assert less_than_or_equal == eval(repr(less_than_or_equal))

0 commit comments

Comments
 (0)