Skip to content

Commit dd7e09b

Browse files
committed
fix adding lt literal and allow boundreference in _to_unbound_term
1 parent 482f4e0 commit dd7e09b

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

pyiceberg/expressions/__init__.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434
from typing import Literal as TypingLiteral
3535

36-
from pydantic import ConfigDict, Field, field_serializer
36+
from pydantic import ConfigDict, Field, field_serializer, field_validator
3737

3838
from pyiceberg.expressions.literals import (
3939
AboveMax,
@@ -52,8 +52,14 @@
5252
ConfigDict = dict
5353

5454

55-
def _to_unbound_term(term: Union[str, UnboundTerm[Any]]) -> UnboundTerm[Any]:
56-
return Reference(term) if isinstance(term, str) else term
55+
def _to_unbound_term(term: Union[str, UnboundTerm[Any], BoundReference[Any]]) -> UnboundTerm[Any]:
56+
if isinstance(term, str):
57+
return Reference(term)
58+
if isinstance(term, UnboundTerm):
59+
return term
60+
if isinstance(term, BoundReference):
61+
return Reference(term.field.name)
62+
raise ValueError(f"Expected UnboundTerm | BoundReference | str, got {type(term).__name__}")
5763

5864

5965
def _to_literal_set(values: Union[Iterable[L], Iterable[Literal[L]]]) -> Set[Literal[L]]:
@@ -744,18 +750,28 @@ def as_bound(self) -> Type[BoundNotIn[L]]:
744750

745751

746752
class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC):
747-
type: TypingLiteral["lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type")
748-
term: Term[L]
753+
type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type")
754+
term: UnboundTerm[L]
749755
literal: Literal[L] = Field(serialization_alias="value")
750756

751757
model_config = ConfigDict(arbitrary_types_allowed=True)
752758

753-
def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal[L]]): # pylint: disable=W0621
754-
super().__init__(term=_to_unbound_term(term), literal=_to_literal(literal))
755-
756-
@field_serializer("term")
757-
def ser_term(self, v: Term[L]) -> str:
758-
return v.name
759+
def __init__(self, *args: Any, **kwargs: Any) -> None:
760+
if args:
761+
if len(args) != 2:
762+
raise TypeError("Expected (term, literal)")
763+
kwargs = {"term": args[0], "literal": args[1], **kwargs}
764+
super().__init__(**kwargs)
765+
766+
@field_validator("term", mode="before")
767+
@classmethod
768+
def _coerce_term(cls, v: Any) -> UnboundTerm[Any]:
769+
return _to_unbound_term(v)
770+
771+
@field_validator("literal", mode="before")
772+
@classmethod
773+
def _coerce_literal(cls, v: Union[L, Literal[L]]) -> Literal[L]:
774+
return _to_literal(v)
759775

760776
@field_serializer("literal")
761777
def ser_literal(self, literal: Literal[L]) -> str:

0 commit comments

Comments
 (0)