|
33 | 33 | ) |
34 | 34 | from typing import Literal as TypingLiteral |
35 | 35 |
|
36 | | -from pydantic import ConfigDict, Field, field_serializer |
| 36 | +from pydantic import ConfigDict, Field, field_serializer, field_validator |
37 | 37 |
|
38 | 38 | from pyiceberg.expressions.literals import ( |
39 | 39 | AboveMax, |
|
52 | 52 | ConfigDict = dict |
53 | 53 |
|
54 | 54 |
|
55 | | -def _to_unbound_term(term: Union[str, UnboundTerm[Any]]) -> UnboundTerm[Any]: |
56 | | - return Reference(term) if isinstance(term, str) else term |
| 55 | +def _to_unbound_term(term: Union[str, UnboundTerm[Any], BoundReference[Any]]) -> UnboundTerm[Any]: |
| 56 | + if isinstance(term, str): |
| 57 | + return Reference(term) |
| 58 | + if isinstance(term, UnboundTerm): |
| 59 | + return term |
| 60 | + if isinstance(term, BoundReference): |
| 61 | + return Reference(term.field.name) |
| 62 | + raise ValueError(f"Expected UnboundTerm | BoundReference | str, got {type(term).__name__}") |
57 | 63 |
|
58 | 64 |
|
59 | 65 | def _to_literal_set(values: Union[Iterable[L], Iterable[Literal[L]]]) -> Set[Literal[L]]: |
@@ -744,18 +750,28 @@ def as_bound(self) -> Type[BoundNotIn[L]]: |
744 | 750 |
|
745 | 751 |
|
746 | 752 | class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC): |
747 | | - type: TypingLiteral["lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type") |
748 | | - term: Term[L] |
| 753 | + type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type") |
| 754 | + term: UnboundTerm[L] |
749 | 755 | literal: Literal[L] = Field(serialization_alias="value") |
750 | 756 |
|
751 | 757 | model_config = ConfigDict(arbitrary_types_allowed=True) |
752 | 758 |
|
753 | | - def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal[L]]): # pylint: disable=W0621 |
754 | | - super().__init__(term=_to_unbound_term(term), literal=_to_literal(literal)) |
755 | | - |
756 | | - @field_serializer("term") |
757 | | - def ser_term(self, v: Term[L]) -> str: |
758 | | - return v.name |
| 759 | + def __init__(self, *args: Any, **kwargs: Any) -> None: |
| 760 | + if args: |
| 761 | + if len(args) != 2: |
| 762 | + raise TypeError("Expected (term, literal)") |
| 763 | + kwargs = {"term": args[0], "literal": args[1], **kwargs} |
| 764 | + super().__init__(**kwargs) |
| 765 | + |
| 766 | + @field_validator("term", mode="before") |
| 767 | + @classmethod |
| 768 | + def _coerce_term(cls, v: Any) -> UnboundTerm[Any]: |
| 769 | + return _to_unbound_term(v) |
| 770 | + |
| 771 | + @field_validator("literal", mode="before") |
| 772 | + @classmethod |
| 773 | + def _coerce_literal(cls, v: Union[L, Literal[L]]) -> Literal[L]: |
| 774 | + return _to_literal(v) |
759 | 775 |
|
760 | 776 | @field_serializer("literal") |
761 | 777 | def ser_literal(self, literal: Literal[L]) -> str: |
|
0 commit comments