Skip to content

Commit 2ebc03d

Browse files
committed
feat: subclass LiteralPredicate instead of using internal class
1 parent 796de63 commit 2ebc03d

File tree

2 files changed

+68
-54
lines changed

2 files changed

+68
-54
lines changed

pyiceberg/expressions/__init__.py

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717

1818
from __future__ import annotations
1919

20+
import typing
2021
from abc import ABC, abstractmethod
2122
from functools import cached_property
2223
from typing import (
2324
Any,
2425
Callable,
26+
ClassVar,
2527
Generic,
2628
Iterable,
2729
Sequence,
@@ -32,7 +34,7 @@
3234
Union,
3335
)
3436

35-
from pydantic import Field
37+
from pydantic import ConfigDict, Field, field_serializer, field_validator
3638

3739
from pyiceberg.expressions.literals import (
3840
AboveMax,
@@ -727,45 +729,37 @@ def as_bound(self) -> Type[BoundNotIn[L]]:
727729
return BoundNotIn[L]
728730

729731

730-
class LiteralPredicate(UnboundPredicate[L], ABC):
731-
literal: Literal[L]
732+
class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC):
733+
op: str = Field(
734+
default="",
735+
alias="type",
736+
validation_alias="type",
737+
serialization_alias="type",
738+
repr=False,
739+
)
740+
term: Term[L]
741+
literal: Literal[L] = Field(serialization_alias="value")
742+
743+
__op__: ClassVar[str] = ""
744+
745+
model_config = ConfigDict(arbitrary_types_allowed=True)
732746

733747
def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal[L]]): # pylint: disable=W0621
734-
super().__init__(term)
735-
self.literal = _to_literal(literal) # pylint: disable=W0621
736-
737-
# ---- JSON (Pydantic) serialization helpers ----
738-
739-
class _LiteralPredicateModel(IcebergBaseModel):
740-
type: str = Field(alias="type")
741-
term: str
742-
value: Any
743-
744-
def _json_op(self) -> str:
745-
mapping = {
746-
EqualTo: "eq",
747-
NotEqualTo: "not-eq",
748-
LessThan: "lt",
749-
LessThanOrEqual: "lt-eq",
750-
GreaterThan: "gt",
751-
GreaterThanOrEqual: "gt-eq",
752-
StartsWith: "starts-with",
753-
NotStartsWith: "not-starts-with",
754-
}
755-
for cls, op in mapping.items():
756-
if isinstance(self, cls):
757-
return op
758-
raise ValueError(f"Unknown LiteralPredicate: {type(self).__name__}")
759-
760-
def model_dump(self, **kwargs: Any) -> dict:
761-
term_name = getattr(self.term, "name", str(self.term))
762-
return self._LiteralPredicateModel(type=self._json_op(), term=term_name, value=self.literal.value).model_dump(**kwargs)
763-
764-
def model_dump_json(self, **kwargs: Any) -> str:
765-
term_name = getattr(self.term, "name", str(self.term))
766-
return self._LiteralPredicateModel(type=self._json_op(), term=term_name, value=self.literal.value).model_dump_json(
767-
**kwargs
768-
)
748+
super().__init__(term=_to_unbound_term(term), literal=_to_literal(literal))
749+
750+
def model_post_init(self, __context: Any) -> None:
751+
if not self.op:
752+
object.__setattr__(self, "op", self.__op__)
753+
elif self.op != self.__op__:
754+
raise ValueError(f"Invalid type {self.op!r}; expected {self.__op__!r}")
755+
756+
@field_serializer("term")
757+
def ser_term(self, v: Term[L]) -> str:
758+
return v.name
759+
760+
@field_serializer("literal")
761+
def ser_literal(self, literal: Literal[L]) -> str:
762+
return "Any"
769763

770764
def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate[L]:
771765
bound_term = self.term.bind(schema, case_sensitive)
@@ -790,6 +784,10 @@ def __eq__(self, other: Any) -> bool:
790784
return self.term == other.term and self.literal == other.literal
791785
return False
792786

787+
def __str__(self) -> str:
788+
"""Return the string representation of the LiteralPredicate class."""
789+
return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})"
790+
793791
def __repr__(self) -> str:
794792
"""Return the string representation of the LiteralPredicate class."""
795793
return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})"
@@ -903,6 +901,8 @@ def as_unbound(self) -> Type[NotStartsWith[L]]:
903901

904902

905903
class EqualTo(LiteralPredicate[L]):
904+
__op__ = "eq"
905+
906906
def __invert__(self) -> NotEqualTo[L]:
907907
"""Transform the Expression into its negated version."""
908908
return NotEqualTo[L](self.term, self.literal)
@@ -913,6 +913,8 @@ def as_bound(self) -> Type[BoundEqualTo[L]]:
913913

914914

915915
class NotEqualTo(LiteralPredicate[L]):
916+
__op__ = "not-eq"
917+
916918
def __invert__(self) -> EqualTo[L]:
917919
"""Transform the Expression into its negated version."""
918920
return EqualTo[L](self.term, self.literal)
@@ -923,6 +925,8 @@ def as_bound(self) -> Type[BoundNotEqualTo[L]]:
923925

924926

925927
class LessThan(LiteralPredicate[L]):
928+
__op__ = "lt"
929+
926930
def __invert__(self) -> GreaterThanOrEqual[L]:
927931
"""Transform the Expression into its negated version."""
928932
return GreaterThanOrEqual[L](self.term, self.literal)
@@ -933,6 +937,8 @@ def as_bound(self) -> Type[BoundLessThan[L]]:
933937

934938

935939
class GreaterThanOrEqual(LiteralPredicate[L]):
940+
__op__ = "gt-eq"
941+
936942
def __invert__(self) -> LessThan[L]:
937943
"""Transform the Expression into its negated version."""
938944
return LessThan[L](self.term, self.literal)
@@ -943,6 +949,8 @@ def as_bound(self) -> Type[BoundGreaterThanOrEqual[L]]:
943949

944950

945951
class GreaterThan(LiteralPredicate[L]):
952+
__op__ = "gt"
953+
946954
def __invert__(self) -> LessThanOrEqual[L]:
947955
"""Transform the Expression into its negated version."""
948956
return LessThanOrEqual[L](self.term, self.literal)
@@ -953,6 +961,8 @@ def as_bound(self) -> Type[BoundGreaterThan[L]]:
953961

954962

955963
class LessThanOrEqual(LiteralPredicate[L]):
964+
__op__ = "lt-eq"
965+
956966
def __invert__(self) -> GreaterThan[L]:
957967
"""Transform the Expression into its negated version."""
958968
return GreaterThan[L](self.term, self.literal)
@@ -963,6 +973,8 @@ def as_bound(self) -> Type[BoundLessThanOrEqual[L]]:
963973

964974

965975
class StartsWith(LiteralPredicate[L]):
976+
__op__ = "starts-with"
977+
966978
def __invert__(self) -> NotStartsWith[L]:
967979
"""Transform the Expression into its negated version."""
968980
return NotStartsWith[L](self.term, self.literal)
@@ -973,6 +985,8 @@ def as_bound(self) -> Type[BoundStartsWith[L]]:
973985

974986

975987
class NotStartsWith(LiteralPredicate[L]):
988+
__op__ = "not-starts-with"
989+
976990
def __invert__(self) -> StartsWith[L]:
977991
"""Transform the Expression into its negated version."""
978992
return StartsWith[L](self.term, self.literal)

tests/expressions/test_expressions.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -923,7 +923,7 @@ def test_bound_less_than_or_equal(term: BoundReference[Any]) -> None:
923923

924924
def test_equal_to() -> None:
925925
equal_to = EqualTo(Reference("a"), literal("a"))
926-
assert equal_to.model_dump_json() == '{"type":"eq","term":"a","value":"a"}'
926+
assert equal_to.model_dump_json() == '{"term":"a","type":"eq","value":"Any"}'
927927
assert str(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
928928
assert repr(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
929929
assert equal_to == eval(repr(equal_to))
@@ -932,7 +932,7 @@ def test_equal_to() -> None:
932932

933933
def test_not_equal_to() -> None:
934934
not_equal_to = NotEqualTo(Reference("a"), literal("a"))
935-
assert not_equal_to.model_dump_json() == '{"type":"not-eq","term":"a","value":"a"}'
935+
assert not_equal_to.model_dump_json() == '{"term":"a","type":"not-eq","value":"Any"}'
936936
assert str(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
937937
assert repr(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
938938
assert not_equal_to == eval(repr(not_equal_to))
@@ -941,7 +941,7 @@ def test_not_equal_to() -> None:
941941

942942
def test_greater_than_or_equal_to() -> None:
943943
greater_than_or_equal_to = GreaterThanOrEqual(Reference("a"), literal("a"))
944-
assert greater_than_or_equal_to.model_dump_json() == '{"type":"gt-eq","term":"a","value":"a"}'
944+
assert greater_than_or_equal_to.model_dump_json() == '{"term":"a","type":"gt-eq","value":"Any"}'
945945
assert str(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
946946
assert repr(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
947947
assert greater_than_or_equal_to == eval(repr(greater_than_or_equal_to))
@@ -950,7 +950,7 @@ def test_greater_than_or_equal_to() -> None:
950950

951951
def test_greater_than() -> None:
952952
greater_than = GreaterThan(Reference("a"), literal("a"))
953-
assert greater_than.model_dump_json() == '{"type":"gt","term":"a","value":"a"}'
953+
assert greater_than.model_dump_json() == '{"term":"a","type":"gt","value":"Any"}'
954954
assert str(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
955955
assert repr(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
956956
assert greater_than == eval(repr(greater_than))
@@ -959,32 +959,32 @@ def test_greater_than() -> None:
959959

960960
def test_less_than() -> None:
961961
less_than = LessThan(Reference("a"), literal("a"))
962-
assert less_than.model_dump_json() == '{"type":"lt","term":"a","value":"a"}'
962+
assert less_than.model_dump_json() == '{"term":"a","type":"lt","value":"Any"}'
963963
assert str(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
964964
assert repr(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
965965
assert less_than == eval(repr(less_than))
966966
assert less_than == pickle.loads(pickle.dumps(less_than))
967967

968968

969-
def test_starts_with() -> None:
970-
starts_with = StartsWith(Reference("a"), literal("a"))
971-
assert starts_with.model_dump_json() == '{"type":"starts-with","term":"a","value":"a"}'
972-
973-
974-
def test_not_starts_with() -> None:
975-
not_starts_with = NotStartsWith(Reference("a"), literal("a"))
976-
assert not_starts_with.model_dump_json() == '{"type":"not-starts-with","term":"a","value":"a"}'
977-
978-
979969
def test_less_than_or_equal() -> None:
980970
less_than_or_equal = LessThanOrEqual(Reference("a"), literal("a"))
981-
assert less_than_or_equal.model_dump_json() == '{"type":"lt-eq","term":"a","value":"a"}'
971+
assert less_than_or_equal.model_dump_json() == '{"term":"a","type":"lt-eq","value":"Any"}'
982972
assert str(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
983973
assert repr(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
984974
assert less_than_or_equal == eval(repr(less_than_or_equal))
985975
assert less_than_or_equal == pickle.loads(pickle.dumps(less_than_or_equal))
986976

987977

978+
def test_starts_with() -> None:
979+
starts_with = StartsWith(Reference("a"), literal("a"))
980+
assert starts_with.model_dump_json() == '{"term":"a","type":"starts-with","value":"Any"}'
981+
982+
983+
def test_not_starts_with() -> None:
984+
not_starts_with = NotStartsWith(Reference("a"), literal("a"))
985+
assert not_starts_with.model_dump_json() == '{"term":"a","type":"not-starts-with","value":"Any"}'
986+
987+
988988
def test_bound_reference_eval(table_schema_simple: Schema) -> None:
989989
"""Test creating a BoundReference and evaluating it on a StructProtocol"""
990990
struct = Record("foovalue", 123, True)

0 commit comments

Comments
 (0)