Skip to content

Commit 35c83c7

Browse files
committed
fix: mypy errors and tests
1 parent 5118748 commit 35c83c7

File tree

4 files changed

+80
-50
lines changed

4 files changed

+80
-50
lines changed

pyiceberg/expressions/__init__.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@
3030
Type,
3131
TypeVar,
3232
Union,
33+
cast,
3334
)
3435
from typing import Literal as TypingLiteral
3536

36-
from pydantic import ConfigDict, Field, field_serializer, field_validator
37+
from pydantic import ConfigDict, Field, field_validator
3738

3839
from pyiceberg.expressions.literals import (
3940
AboveMax,
@@ -751,31 +752,50 @@ def as_bound(self) -> Type[BoundNotIn[L]]:
751752

752753
class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC):
753754
type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type")
754-
term: UnboundTerm[L]
755-
literal: Literal[L] = Field(serialization_alias="value")
755+
term: UnboundTerm[Any]
756+
value: Literal[L] = Field(alias="literal", serialization_alias="value")
756757

757-
model_config = ConfigDict(arbitrary_types_allowed=True)
758+
model_config = ConfigDict(populate_by_name=True, frozen=True, arbitrary_types_allowed=True)
759+
760+
def __init__(
761+
self,
762+
term: Union[str, UnboundTerm[Any], BoundReference[Any]],
763+
literal: Union[L, Literal[L], None] = None,
764+
**data: Any,
765+
) -> None: # pylint: disable=W0621
766+
extra = dict(data)
767+
768+
literal_candidates = []
769+
if literal is not None:
770+
literal_candidates.append(literal)
771+
if "literal" in extra:
772+
literal_candidates.append(extra.pop("literal"))
773+
if "value" in extra:
774+
literal_candidates.append(extra.pop("value"))
758775

759-
def __init__(self, *args: Any, **kwargs: Any) -> None:
760-
if args:
761-
if len(args) != 2:
762-
raise TypeError("Expected (term, literal)")
763-
kwargs = {"term": args[0], "literal": args[1], **kwargs}
764-
super().__init__(**kwargs)
776+
literal_candidates = [candidate for candidate in literal_candidates if candidate is not None]
777+
778+
if not literal_candidates:
779+
raise TypeError("LiteralPredicate requires a literal or value argument")
780+
if len(literal_candidates) > 1:
781+
raise TypeError("literal/value provided multiple times")
782+
783+
init = cast("Callable[..., None]", IcebergBaseModel.__init__)
784+
init(self, term=_to_unbound_term(term), literal=_to_literal(literal_candidates[0]), **extra)
765785

766786
@field_validator("term", mode="before")
767787
@classmethod
768-
def _coerce_term(cls, v: Any) -> UnboundTerm[Any]:
769-
return _to_unbound_term(v)
788+
def _convert_term(cls, value: Any) -> UnboundTerm[Any]:
789+
return _to_unbound_term(value)
770790

771-
@field_validator("literal", mode="before")
791+
@field_validator("value", mode="before")
772792
@classmethod
773-
def _coerce_literal(cls, v: Union[L, Literal[L]]) -> Literal[L]:
774-
return _to_literal(v)
793+
def _convert_value(cls, value: Any) -> Literal[Any]:
794+
return _to_literal(value)
775795

776-
@field_serializer("literal")
777-
def ser_literal(self, literal: Literal[L]) -> str:
778-
return "Any"
796+
@property
797+
def literal(self) -> Literal[L]:
798+
return self.value
779799

780800
def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate[L]:
781801
bound_term = self.term.bind(schema, case_sensitive)

pyiceberg/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _try_import(module_name: str, extras_name: Optional[str] = None) -> types.Mo
120120
raise NotInstalledError(msg) from None
121121

122122

123-
def _transform_literal(func: Callable[[L], L], lit: Literal[L]) -> Literal[L]:
123+
def _transform_literal(func: Callable[[Any], Any], lit: Literal[L]) -> Literal[L]:
124124
"""Small helper to upwrap the value from the literal, and wrap it again."""
125125
return literal(func(lit.value))
126126

tests/expressions/test_evaluator.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pyiceberg.conversions import to_bytes
2323
from pyiceberg.expressions import (
2424
And,
25+
BooleanExpression,
2526
EqualTo,
2627
GreaterThan,
2728
GreaterThanOrEqual,
@@ -30,6 +31,7 @@
3031
IsNull,
3132
LessThan,
3233
LessThanOrEqual,
34+
LiteralPredicate,
3335
Not,
3436
NotEqualTo,
3537
NotIn,
@@ -301,7 +303,7 @@ def test_missing_stats() -> None:
301303
upper_bounds=None,
302304
)
303305

304-
expressions = [
306+
expressions: list[BooleanExpression] = [
305307
LessThan("no_stats", 5),
306308
LessThanOrEqual("no_stats", 30),
307309
EqualTo("no_stats", 70),
@@ -324,7 +326,7 @@ def test_zero_record_file_stats(schema_data_file: Schema) -> None:
324326
file_path="file_1.parquet", file_format=FileFormat.PARQUET, partition=Record(), record_count=0
325327
)
326328

327-
expressions = [
329+
expressions: list[BooleanExpression] = [
328330
LessThan("no_stats", 5),
329331
LessThanOrEqual("no_stats", 30),
330332
EqualTo("no_stats", 70),
@@ -683,26 +685,27 @@ def data_file_nan() -> DataFile:
683685

684686

685687
def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_file_nan: Schema, data_file_nan: DataFile) -> None:
686-
for operator in [LessThan, LessThanOrEqual]:
687-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
688+
operators: tuple[type[LiteralPredicate[Any]], ...] = (LessThan, LessThanOrEqual)
689+
for operator in operators:
690+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan)
688691
assert not should_read, "Should not match: all nan column doesn't contain number"
689692

690-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
693+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan)
691694
assert not should_read, "Should not match: 1 is smaller than lower bound"
692695

693-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan) # type: ignore[arg-type]
696+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan)
694697
assert should_read, "Should match: 10 is larger than lower bound"
695698

696-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
699+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan)
697700
assert should_read, "Should match: no visibility"
698701

699-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
702+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan)
700703
assert not should_read, "Should not match: all nan column doesn't contain number"
701704

702-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
705+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan)
703706
assert not should_read, "Should not match: 1 is smaller than lower bound"
704707

705-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval( # type: ignore[arg-type]
708+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval(
706709
data_file_nan
707710
)
708711
assert should_read, "Should match: 10 larger than lower bound"
@@ -711,31 +714,32 @@ def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_f
711714
def test_inclusive_metrics_evaluator_greater_than_and_greater_than_equal(
712715
schema_data_file_nan: Schema, data_file_nan: DataFile
713716
) -> None:
714-
for operator in [GreaterThan, GreaterThanOrEqual]:
715-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
717+
operators: tuple[type[LiteralPredicate[Any]], ...] = (GreaterThan, GreaterThanOrEqual)
718+
for operator in operators:
719+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan)
716720
assert not should_read, "Should not match: all nan column doesn't contain number"
717721

718-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
722+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan)
719723
assert should_read, "Should match: upper bound is larger than 1"
720724

721-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan) # type: ignore[arg-type]
725+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan)
722726
assert should_read, "Should match: upper bound is larger than 10"
723727

724-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
728+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan)
725729
assert should_read, "Should match: no visibility"
726730

727-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
731+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan)
728732
assert not should_read, "Should not match: all nan column doesn't contain number"
729733

730-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
734+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan)
731735
assert should_read, "Should match: 1 is smaller than upper bound"
732736

733-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval( # type: ignore[arg-type]
737+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval(
734738
data_file_nan
735739
)
736740
assert should_read, "Should match: 10 is smaller than upper bound"
737741

738-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 30)).eval(data_file_nan) # type: ignore[arg-type]
742+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 30)).eval(data_file_nan)
739743
assert not should_read, "Should not match: 30 is greater than upper bound"
740744

741745

@@ -1162,7 +1166,7 @@ def test_strict_missing_stats(strict_data_file_schema: Schema, strict_data_file_
11621166
upper_bounds=None,
11631167
)
11641168

1165-
expressions = [
1169+
expressions: list[BooleanExpression] = [
11661170
LessThan("no_stats", 5),
11671171
LessThanOrEqual("no_stats", 30),
11681172
EqualTo("no_stats", 70),
@@ -1185,7 +1189,7 @@ def test_strict_zero_record_file_stats(strict_data_file_schema: Schema) -> None:
11851189
file_path="file_1.parquet", file_format=FileFormat.PARQUET, partition=Record(), record_count=0
11861190
)
11871191

1188-
expressions = [
1192+
expressions: list[BooleanExpression] = [
11891193
LessThan("no_stats", 5),
11901194
LessThanOrEqual("no_stats", 30),
11911195
EqualTo("no_stats", 70),

tests/expressions/test_expressions.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
IsNull,
5151
LessThan,
5252
LessThanOrEqual,
53+
LiteralPredicate,
5354
Not,
5455
NotEqualTo,
5556
NotIn,
@@ -64,7 +65,7 @@
6465
from pyiceberg.expressions.literals import Literal, literal
6566
from pyiceberg.expressions.visitors import _from_byte_buffer
6667
from pyiceberg.schema import Accessor, Schema
67-
from pyiceberg.typedef import Record
68+
from pyiceberg.typedef import L, Record
6869
from pyiceberg.types import (
6970
DecimalType,
7071
DoubleType,
@@ -935,7 +936,7 @@ def test_bound_less_than_or_equal(term: BoundReference[Any]) -> None:
935936

936937
def test_equal_to() -> None:
937938
equal_to = EqualTo(Reference("a"), literal("a"))
938-
assert equal_to.model_dump_json() == '{"term":"a","type":"eq","value":"Any"}'
939+
assert equal_to.model_dump_json() == '{"term":"a","type":"eq","value":"a"}'
939940
assert str(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
940941
assert repr(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
941942
assert equal_to == eval(repr(equal_to))
@@ -944,7 +945,7 @@ def test_equal_to() -> None:
944945

945946
def test_not_equal_to() -> None:
946947
not_equal_to = NotEqualTo(Reference("a"), literal("a"))
947-
assert not_equal_to.model_dump_json() == '{"term":"a","type":"not-eq","value":"Any"}'
948+
assert not_equal_to.model_dump_json() == '{"term":"a","type":"not-eq","value":"a"}'
948949
assert str(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
949950
assert repr(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
950951
assert not_equal_to == eval(repr(not_equal_to))
@@ -953,7 +954,7 @@ def test_not_equal_to() -> None:
953954

954955
def test_greater_than_or_equal_to() -> None:
955956
greater_than_or_equal_to = GreaterThanOrEqual(Reference("a"), literal("a"))
956-
assert greater_than_or_equal_to.model_dump_json() == '{"term":"a","type":"gt-eq","value":"Any"}'
957+
assert greater_than_or_equal_to.model_dump_json() == '{"term":"a","type":"gt-eq","value":"a"}'
957958
assert str(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
958959
assert repr(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
959960
assert greater_than_or_equal_to == eval(repr(greater_than_or_equal_to))
@@ -962,7 +963,7 @@ def test_greater_than_or_equal_to() -> None:
962963

963964
def test_greater_than() -> None:
964965
greater_than = GreaterThan(Reference("a"), literal("a"))
965-
assert greater_than.model_dump_json() == '{"term":"a","type":"gt","value":"Any"}'
966+
assert greater_than.model_dump_json() == '{"term":"a","type":"gt","value":"a"}'
966967
assert str(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
967968
assert repr(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
968969
assert greater_than == eval(repr(greater_than))
@@ -971,7 +972,7 @@ def test_greater_than() -> None:
971972

972973
def test_less_than() -> None:
973974
less_than = LessThan(Reference("a"), literal("a"))
974-
assert less_than.model_dump_json() == '{"term":"a","type":"lt","value":"Any"}'
975+
assert less_than.model_dump_json() == '{"term":"a","type":"lt","value":"a"}'
975976
assert str(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
976977
assert repr(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
977978
assert less_than == eval(repr(less_than))
@@ -980,7 +981,7 @@ def test_less_than() -> None:
980981

981982
def test_less_than_or_equal() -> None:
982983
less_than_or_equal = LessThanOrEqual(Reference("a"), literal("a"))
983-
assert less_than_or_equal.model_dump_json() == '{"term":"a","type":"lt-eq","value":"Any"}'
984+
assert less_than_or_equal.model_dump_json() == '{"term":"a","type":"lt-eq","value":"a"}'
984985
assert str(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
985986
assert repr(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
986987
assert less_than_or_equal == eval(repr(less_than_or_equal))
@@ -989,12 +990,12 @@ def test_less_than_or_equal() -> None:
989990

990991
def test_starts_with() -> None:
991992
starts_with = StartsWith(Reference("a"), literal("a"))
992-
assert starts_with.model_dump_json() == '{"term":"a","type":"starts-with","value":"Any"}'
993+
assert starts_with.model_dump_json() == '{"term":"a","type":"starts-with","value":"a"}'
993994

994995

995996
def test_not_starts_with() -> None:
996997
not_starts_with = NotStartsWith(Reference("a"), literal("a"))
997-
assert not_starts_with.model_dump_json() == '{"term":"a","type":"not-starts-with","value":"Any"}'
998+
assert not_starts_with.model_dump_json() == '{"term":"a","type":"not-starts-with","value":"a"}'
998999

9991000

10001001
def test_bound_reference_eval(table_schema_simple: Schema) -> None:
@@ -1235,7 +1236,12 @@ def test_bind_ambiguous_name() -> None:
12351236
# |_| |_|\_, |_| \_, |
12361237
# |__/ |__/
12371238

1238-
assert_type(EqualTo("a", "b"), EqualTo[str])
1239+
1240+
def _assert_literal_predicate_type(expr: LiteralPredicate[L]) -> None:
1241+
assert_type(expr, LiteralPredicate[L])
1242+
1243+
1244+
_assert_literal_predicate_type(EqualTo("a", "b"))
12391245
assert_type(In("a", ("a", "b", "c")), In[str])
12401246
assert_type(In("a", (1, 2, 3)), In[int])
12411247
assert_type(NotIn("a", ("a", "b", "c")), NotIn[str])

0 commit comments

Comments
 (0)