1717
1818from __future__ import annotations
1919
20+ import typing
2021from abc import ABC , abstractmethod
2122from functools import cached_property
2223from typing import (
2324 Any ,
2425 Callable ,
26+ ClassVar ,
2527 Generic ,
2628 Iterable ,
2729 Sequence ,
3234 Union ,
3335)
3436
35- from pydantic import Field
37+ from pydantic import ConfigDict , Field , field_serializer , field_validator
3638
3739from pyiceberg .expressions .literals import (
3840 AboveMax ,
@@ -727,45 +729,37 @@ def as_bound(self) -> Type[BoundNotIn[L]]:
727729 return BoundNotIn [L ]
728730
729731
730- class LiteralPredicate (UnboundPredicate [L ], ABC ):
731- literal : Literal [L ]
732+ class LiteralPredicate (IcebergBaseModel , UnboundPredicate [L ], ABC ):
733+ op : str = Field (
734+ default = "" ,
735+ alias = "type" ,
736+ validation_alias = "type" ,
737+ serialization_alias = "type" ,
738+ repr = False ,
739+ )
740+ term : Term [L ]
741+ literal : Literal [L ] = Field (serialization_alias = "value" )
742+
743+ __op__ : ClassVar [str ] = ""
744+
745+ model_config = ConfigDict (arbitrary_types_allowed = True )
732746
733747 def __init__ (self , term : Union [str , UnboundTerm [Any ]], literal : Union [L , Literal [L ]]): # pylint: disable=W0621
734- super ().__init__ (term )
735- self .literal = _to_literal (literal ) # pylint: disable=W0621
736-
737- # ---- JSON (Pydantic) serialization helpers ----
738-
739- class _LiteralPredicateModel (IcebergBaseModel ):
740- type : str = Field (alias = "type" )
741- term : str
742- value : Any
743-
744- def _json_op (self ) -> str :
745- mapping = {
746- EqualTo : "eq" ,
747- NotEqualTo : "not-eq" ,
748- LessThan : "lt" ,
749- LessThanOrEqual : "lt-eq" ,
750- GreaterThan : "gt" ,
751- GreaterThanOrEqual : "gt-eq" ,
752- StartsWith : "starts-with" ,
753- NotStartsWith : "not-starts-with" ,
754- }
755- for cls , op in mapping .items ():
756- if isinstance (self , cls ):
757- return op
758- raise ValueError (f"Unknown LiteralPredicate: { type (self ).__name__ } " )
759-
760- def model_dump (self , ** kwargs : Any ) -> dict :
761- term_name = getattr (self .term , "name" , str (self .term ))
762- return self ._LiteralPredicateModel (type = self ._json_op (), term = term_name , value = self .literal .value ).model_dump (** kwargs )
763-
764- def model_dump_json (self , ** kwargs : Any ) -> str :
765- term_name = getattr (self .term , "name" , str (self .term ))
766- return self ._LiteralPredicateModel (type = self ._json_op (), term = term_name , value = self .literal .value ).model_dump_json (
767- ** kwargs
768- )
748+ super ().__init__ (term = _to_unbound_term (term ), literal = _to_literal (literal ))
749+
750+ def model_post_init (self , __context : Any ) -> None :
751+ if not self .op :
752+ object .__setattr__ (self , "op" , self .__op__ )
753+ elif self .op != self .__op__ :
754+ raise ValueError (f"Invalid type { self .op !r} ; expected { self .__op__ !r} " )
755+
756+ @field_serializer ("term" )
757+ def ser_term (self , v : Term [L ]) -> str :
758+ return v .name
759+
760+ @field_serializer ("literal" )
761+ def ser_literal (self , literal : Literal [L ]) -> str :
762+ return "Any"
769763
770764 def bind (self , schema : Schema , case_sensitive : bool = True ) -> BoundLiteralPredicate [L ]:
771765 bound_term = self .term .bind (schema , case_sensitive )
@@ -790,6 +784,10 @@ def __eq__(self, other: Any) -> bool:
790784 return self .term == other .term and self .literal == other .literal
791785 return False
792786
787+ def __str__ (self ) -> str :
788+ """Return the string representation of the LiteralPredicate class."""
789+ return f"{ str (self .__class__ .__name__ )} (term={ repr (self .term )} , literal={ repr (self .literal )} )"
790+
793791 def __repr__ (self ) -> str :
794792 """Return the string representation of the LiteralPredicate class."""
795793 return f"{ str (self .__class__ .__name__ )} (term={ repr (self .term )} , literal={ repr (self .literal )} )"
@@ -903,6 +901,8 @@ def as_unbound(self) -> Type[NotStartsWith[L]]:
903901
904902
905903class EqualTo (LiteralPredicate [L ]):
904+ __op__ = "eq"
905+
906906 def __invert__ (self ) -> NotEqualTo [L ]:
907907 """Transform the Expression into its negated version."""
908908 return NotEqualTo [L ](self .term , self .literal )
@@ -913,6 +913,8 @@ def as_bound(self) -> Type[BoundEqualTo[L]]:
913913
914914
915915class NotEqualTo (LiteralPredicate [L ]):
916+ __op__ = "not-eq"
917+
916918 def __invert__ (self ) -> EqualTo [L ]:
917919 """Transform the Expression into its negated version."""
918920 return EqualTo [L ](self .term , self .literal )
@@ -923,6 +925,8 @@ def as_bound(self) -> Type[BoundNotEqualTo[L]]:
923925
924926
925927class LessThan (LiteralPredicate [L ]):
928+ __op__ = "lt"
929+
926930 def __invert__ (self ) -> GreaterThanOrEqual [L ]:
927931 """Transform the Expression into its negated version."""
928932 return GreaterThanOrEqual [L ](self .term , self .literal )
@@ -933,6 +937,8 @@ def as_bound(self) -> Type[BoundLessThan[L]]:
933937
934938
935939class GreaterThanOrEqual (LiteralPredicate [L ]):
940+ __op__ = "gt-eq"
941+
936942 def __invert__ (self ) -> LessThan [L ]:
937943 """Transform the Expression into its negated version."""
938944 return LessThan [L ](self .term , self .literal )
@@ -943,6 +949,8 @@ def as_bound(self) -> Type[BoundGreaterThanOrEqual[L]]:
943949
944950
945951class GreaterThan (LiteralPredicate [L ]):
952+ __op__ = "gt"
953+
946954 def __invert__ (self ) -> LessThanOrEqual [L ]:
947955 """Transform the Expression into its negated version."""
948956 return LessThanOrEqual [L ](self .term , self .literal )
@@ -953,6 +961,8 @@ def as_bound(self) -> Type[BoundGreaterThan[L]]:
953961
954962
955963class LessThanOrEqual (LiteralPredicate [L ]):
964+ __op__ = "lt-eq"
965+
956966 def __invert__ (self ) -> GreaterThan [L ]:
957967 """Transform the Expression into its negated version."""
958968 return GreaterThan [L ](self .term , self .literal )
@@ -963,6 +973,8 @@ def as_bound(self) -> Type[BoundLessThanOrEqual[L]]:
963973
964974
965975class StartsWith (LiteralPredicate [L ]):
976+ __op__ = "starts-with"
977+
966978 def __invert__ (self ) -> NotStartsWith [L ]:
967979 """Transform the Expression into its negated version."""
968980 return NotStartsWith [L ](self .term , self .literal )
@@ -973,6 +985,8 @@ def as_bound(self) -> Type[BoundStartsWith[L]]:
973985
974986
975987class NotStartsWith (LiteralPredicate [L ]):
988+ __op__ = "not-starts-with"
989+
976990 def __invert__ (self ) -> StartsWith [L ]:
977991 """Transform the Expression into its negated version."""
978992 return StartsWith [L ](self .term , self .literal )
0 commit comments