Skip to content

Commit c1f7384

Browse files
committed
feat: subclass LiteralPredicate instead of using internal class
1 parent 9257a6d commit c1f7384

2 files changed

Lines changed: 68 additions & 54 deletions

File tree

pyiceberg/expressions/__init__.py

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717

1818
from __future__ import annotations
1919

20+
import typing
2021
from abc import ABC, abstractmethod
2122
from functools import cached_property
2223
from typing import (
2324
Any,
2425
Callable,
26+
ClassVar,
2527
Generic,
2628
Iterable,
2729
Sequence,
@@ -35,7 +37,7 @@
3537

3638
from pydantic import Field
3739

38-
from pydantic import Field
40+
from pydantic import ConfigDict, Field, field_serializer, field_validator
3941

4042
from pyiceberg.expressions.literals import (
4143
AboveMax,
@@ -745,45 +747,37 @@ def as_bound(self) -> Type[BoundNotIn[L]]:
745747
return BoundNotIn[L]
746748

747749

748-
class LiteralPredicate(UnboundPredicate[L], ABC):
749-
literal: Literal[L]
750+
class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC):
751+
op: str = Field(
752+
default="",
753+
alias="type",
754+
validation_alias="type",
755+
serialization_alias="type",
756+
repr=False,
757+
)
758+
term: Term[L]
759+
literal: Literal[L] = Field(serialization_alias="value")
760+
761+
__op__: ClassVar[str] = ""
762+
763+
model_config = ConfigDict(arbitrary_types_allowed=True)
750764

751765
def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal[L]]): # pylint: disable=W0621
752-
super().__init__(term)
753-
self.literal = _to_literal(literal) # pylint: disable=W0621
754-
755-
# ---- JSON (Pydantic) serialization helpers ----
756-
757-
class _LiteralPredicateModel(IcebergBaseModel):
758-
type: str = Field(alias="type")
759-
term: str
760-
value: Any
761-
762-
def _json_op(self) -> str:
763-
mapping = {
764-
EqualTo: "eq",
765-
NotEqualTo: "not-eq",
766-
LessThan: "lt",
767-
LessThanOrEqual: "lt-eq",
768-
GreaterThan: "gt",
769-
GreaterThanOrEqual: "gt-eq",
770-
StartsWith: "starts-with",
771-
NotStartsWith: "not-starts-with",
772-
}
773-
for cls, op in mapping.items():
774-
if isinstance(self, cls):
775-
return op
776-
raise ValueError(f"Unknown LiteralPredicate: {type(self).__name__}")
777-
778-
def model_dump(self, **kwargs: Any) -> dict:
779-
term_name = getattr(self.term, "name", str(self.term))
780-
return self._LiteralPredicateModel(type=self._json_op(), term=term_name, value=self.literal.value).model_dump(**kwargs)
781-
782-
def model_dump_json(self, **kwargs: Any) -> str:
783-
term_name = getattr(self.term, "name", str(self.term))
784-
return self._LiteralPredicateModel(type=self._json_op(), term=term_name, value=self.literal.value).model_dump_json(
785-
**kwargs
786-
)
766+
super().__init__(term=_to_unbound_term(term), literal=_to_literal(literal))
767+
768+
def model_post_init(self, __context: Any) -> None:
769+
if not self.op:
770+
object.__setattr__(self, "op", self.__op__)
771+
elif self.op != self.__op__:
772+
raise ValueError(f"Invalid type {self.op!r}; expected {self.__op__!r}")
773+
774+
@field_serializer("term")
775+
def ser_term(self, v: Term[L]) -> str:
776+
return v.name
777+
778+
@field_serializer("literal")
779+
def ser_literal(self, literal: Literal[L]) -> str:
780+
return "Any"
787781

788782
def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate[L]:
789783
bound_term = self.term.bind(schema, case_sensitive)
@@ -808,6 +802,10 @@ def __eq__(self, other: Any) -> bool:
808802
return self.term == other.term and self.literal == other.literal
809803
return False
810804

805+
def __str__(self) -> str:
806+
"""Return the string representation of the LiteralPredicate class."""
807+
return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})"
808+
811809
def __repr__(self) -> str:
812810
"""Return the string representation of the LiteralPredicate class."""
813811
return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})"
@@ -921,6 +919,8 @@ def as_unbound(self) -> Type[NotStartsWith[L]]:
921919

922920

923921
class EqualTo(LiteralPredicate[L]):
922+
__op__ = "eq"
923+
924924
def __invert__(self) -> NotEqualTo[L]:
925925
"""Transform the Expression into its negated version."""
926926
return NotEqualTo[L](self.term, self.literal)
@@ -931,6 +931,8 @@ def as_bound(self) -> Type[BoundEqualTo[L]]:
931931

932932

933933
class NotEqualTo(LiteralPredicate[L]):
934+
__op__ = "not-eq"
935+
934936
def __invert__(self) -> EqualTo[L]:
935937
"""Transform the Expression into its negated version."""
936938
return EqualTo[L](self.term, self.literal)
@@ -941,6 +943,8 @@ def as_bound(self) -> Type[BoundNotEqualTo[L]]:
941943

942944

943945
class LessThan(LiteralPredicate[L]):
946+
__op__ = "lt"
947+
944948
def __invert__(self) -> GreaterThanOrEqual[L]:
945949
"""Transform the Expression into its negated version."""
946950
return GreaterThanOrEqual[L](self.term, self.literal)
@@ -951,6 +955,8 @@ def as_bound(self) -> Type[BoundLessThan[L]]:
951955

952956

953957
class GreaterThanOrEqual(LiteralPredicate[L]):
958+
__op__ = "gt-eq"
959+
954960
def __invert__(self) -> LessThan[L]:
955961
"""Transform the Expression into its negated version."""
956962
return LessThan[L](self.term, self.literal)
@@ -961,6 +967,8 @@ def as_bound(self) -> Type[BoundGreaterThanOrEqual[L]]:
961967

962968

963969
class GreaterThan(LiteralPredicate[L]):
970+
__op__ = "gt"
971+
964972
def __invert__(self) -> LessThanOrEqual[L]:
965973
"""Transform the Expression into its negated version."""
966974
return LessThanOrEqual[L](self.term, self.literal)
@@ -971,6 +979,8 @@ def as_bound(self) -> Type[BoundGreaterThan[L]]:
971979

972980

973981
class LessThanOrEqual(LiteralPredicate[L]):
982+
__op__ = "lt-eq"
983+
974984
def __invert__(self) -> GreaterThan[L]:
975985
"""Transform the Expression into its negated version."""
976986
return GreaterThan[L](self.term, self.literal)
@@ -981,6 +991,8 @@ def as_bound(self) -> Type[BoundLessThanOrEqual[L]]:
981991

982992

983993
class StartsWith(LiteralPredicate[L]):
994+
__op__ = "starts-with"
995+
984996
def __invert__(self) -> NotStartsWith[L]:
985997
"""Transform the Expression into its negated version."""
986998
return NotStartsWith[L](self.term, self.literal)
@@ -991,6 +1003,8 @@ def as_bound(self) -> Type[BoundStartsWith[L]]:
9911003

9921004

9931005
class NotStartsWith(LiteralPredicate[L]):
1006+
__op__ = "not-starts-with"
1007+
9941008
def __invert__(self) -> StartsWith[L]:
9951009
"""Transform the Expression into its negated version."""
9961010
return StartsWith[L](self.term, self.literal)

tests/expressions/test_expressions.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@ def test_bound_less_than_or_equal(term: BoundReference[Any]) -> None:
935935

936936
def test_equal_to() -> None:
937937
equal_to = EqualTo(Reference("a"), literal("a"))
938-
assert equal_to.model_dump_json() == '{"type":"eq","term":"a","value":"a"}'
938+
assert equal_to.model_dump_json() == '{"term":"a","type":"eq","value":"Any"}'
939939
assert str(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
940940
assert repr(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
941941
assert equal_to == eval(repr(equal_to))
@@ -944,7 +944,7 @@ def test_equal_to() -> None:
944944

945945
def test_not_equal_to() -> None:
946946
not_equal_to = NotEqualTo(Reference("a"), literal("a"))
947-
assert not_equal_to.model_dump_json() == '{"type":"not-eq","term":"a","value":"a"}'
947+
assert not_equal_to.model_dump_json() == '{"term":"a","type":"not-eq","value":"Any"}'
948948
assert str(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
949949
assert repr(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
950950
assert not_equal_to == eval(repr(not_equal_to))
@@ -953,7 +953,7 @@ def test_not_equal_to() -> None:
953953

954954
def test_greater_than_or_equal_to() -> None:
955955
greater_than_or_equal_to = GreaterThanOrEqual(Reference("a"), literal("a"))
956-
assert greater_than_or_equal_to.model_dump_json() == '{"type":"gt-eq","term":"a","value":"a"}'
956+
assert greater_than_or_equal_to.model_dump_json() == '{"term":"a","type":"gt-eq","value":"Any"}'
957957
assert str(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
958958
assert repr(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
959959
assert greater_than_or_equal_to == eval(repr(greater_than_or_equal_to))
@@ -962,7 +962,7 @@ def test_greater_than_or_equal_to() -> None:
962962

963963
def test_greater_than() -> None:
964964
greater_than = GreaterThan(Reference("a"), literal("a"))
965-
assert greater_than.model_dump_json() == '{"type":"gt","term":"a","value":"a"}'
965+
assert greater_than.model_dump_json() == '{"term":"a","type":"gt","value":"Any"}'
966966
assert str(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
967967
assert repr(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
968968
assert greater_than == eval(repr(greater_than))
@@ -971,32 +971,32 @@ def test_greater_than() -> None:
971971

972972
def test_less_than() -> None:
973973
less_than = LessThan(Reference("a"), literal("a"))
974-
assert less_than.model_dump_json() == '{"type":"lt","term":"a","value":"a"}'
974+
assert less_than.model_dump_json() == '{"term":"a","type":"lt","value":"Any"}'
975975
assert str(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
976976
assert repr(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
977977
assert less_than == eval(repr(less_than))
978978
assert less_than == pickle.loads(pickle.dumps(less_than))
979979

980980

981-
def test_starts_with() -> None:
982-
starts_with = StartsWith(Reference("a"), literal("a"))
983-
assert starts_with.model_dump_json() == '{"type":"starts-with","term":"a","value":"a"}'
984-
985-
986-
def test_not_starts_with() -> None:
987-
not_starts_with = NotStartsWith(Reference("a"), literal("a"))
988-
assert not_starts_with.model_dump_json() == '{"type":"not-starts-with","term":"a","value":"a"}'
989-
990-
991981
def test_less_than_or_equal() -> None:
992982
less_than_or_equal = LessThanOrEqual(Reference("a"), literal("a"))
993-
assert less_than_or_equal.model_dump_json() == '{"type":"lt-eq","term":"a","value":"a"}'
983+
assert less_than_or_equal.model_dump_json() == '{"term":"a","type":"lt-eq","value":"Any"}'
994984
assert str(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
995985
assert repr(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
996986
assert less_than_or_equal == eval(repr(less_than_or_equal))
997987
assert less_than_or_equal == pickle.loads(pickle.dumps(less_than_or_equal))
998988

999989

990+
def test_starts_with() -> None:
991+
starts_with = StartsWith(Reference("a"), literal("a"))
992+
assert starts_with.model_dump_json() == '{"term":"a","type":"starts-with","value":"Any"}'
993+
994+
995+
def test_not_starts_with() -> None:
996+
not_starts_with = NotStartsWith(Reference("a"), literal("a"))
997+
assert not_starts_with.model_dump_json() == '{"term":"a","type":"not-starts-with","value":"Any"}'
998+
999+
10001000
def test_bound_reference_eval(table_schema_simple: Schema) -> None:
10011001
"""Test creating a BoundReference and evaluating it on a StructProtocol"""
10021002
struct = Record("foovalue", 123, True)

0 commit comments

Comments
 (0)