diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 84478f24cf..e8c264ac68 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -22,6 +22,7 @@ overload, ) +import annotated_types from pydantic import BaseModel, EmailStr from pydantic.fields import FieldInfo as PydanticFieldInfo from sqlalchemy import ( @@ -214,10 +215,10 @@ def Field( exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, const: Optional[bool] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, + gt: Optional[annotated_types.SupportsGt] = None, + ge: Optional[annotated_types.SupportsGe] = None, + lt: Optional[annotated_types.SupportsLt] = None, + le: Optional[annotated_types.SupportsLe] = None, multiple_of: Optional[float] = None, max_digits: Optional[int] = None, decimal_places: Optional[int] = None, @@ -257,10 +258,10 @@ def Field( exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, const: Optional[bool] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, + gt: Optional[annotated_types.SupportsGt] = None, + ge: Optional[annotated_types.SupportsGe] = None, + lt: Optional[annotated_types.SupportsLt] = None, + le: Optional[annotated_types.SupportsLe] = None, multiple_of: Optional[float] = None, max_digits: Optional[int] = None, decimal_places: Optional[int] = None, @@ -309,10 +310,10 @@ def Field( exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, const: Optional[bool] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, + gt: Optional[annotated_types.SupportsGt] = None, + ge: Optional[annotated_types.SupportsGe] = None, + lt: Optional[annotated_types.SupportsLt] = None, + le: Optional[annotated_types.SupportsLe] = None, multiple_of: Optional[float] = None, max_digits: Optional[int] = None, decimal_places: Optional[int] = None, @@ -342,10 +343,10 @@ def Field( exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, const: Optional[bool] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, + gt: Optional[annotated_types.SupportsGt] = None, + ge: Optional[annotated_types.SupportsGe] = None, + lt: Optional[annotated_types.SupportsLt] = None, + le: Optional[annotated_types.SupportsLe] = None, multiple_of: Optional[float] = None, max_digits: Optional[int] = None, decimal_places: Optional[int] = None, diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py index 140b02fd9b..c7c67f1fcf 100644 --- a/tests/test_pydantic/test_field.py +++ b/tests/test_pydantic/test_field.py @@ -54,3 +54,83 @@ class Model(SQLModel): instance = Model(id=123, foo="bar") assert "foo=" not in repr(instance) + + +def test_gt(): + class Model(SQLModel): + int_value: int = Field(gt=10) + tuple_value: tuple[int, int] = Field(gt=(1, 2)) + + Model(int_value=11, tuple_value=(1, 3)) + + with pytest.raises(ValidationError) as exc_info: + Model(int_value=10, tuple_value=(1, 3)) + assert len(exc_info.value.errors()) == 1 + assert exc_info.value.errors()[0]["type"] == "greater_than" + assert exc_info.value.errors()[0]["loc"] == ("int_value",) + + with pytest.raises(ValidationError) as exc_info_2: + Model(int_value=11, tuple_value=(1, 2)) + assert len(exc_info_2.value.errors()) == 1 + assert exc_info_2.value.errors()[0]["type"] == "greater_than" + assert exc_info_2.value.errors()[0]["loc"] == ("tuple_value",) + + +def test_ge(): + class Model(SQLModel): + int_value: int = Field(ge=10) + tuple_value: tuple[int, int] = Field(ge=(1, 2)) + + Model(int_value=10, tuple_value=(1, 2)) + + with pytest.raises(ValidationError) as exc_info: + Model(int_value=9, tuple_value=(1, 2)) + assert len(exc_info.value.errors()) == 1 + assert exc_info.value.errors()[0]["type"] == "greater_than_equal" + assert exc_info.value.errors()[0]["loc"] == ("int_value",) + + with pytest.raises(ValidationError) as exc_info_2: + Model(int_value=10, tuple_value=(1, 1)) + assert len(exc_info_2.value.errors()) == 1 + assert exc_info_2.value.errors()[0]["type"] == "greater_than_equal" + assert exc_info_2.value.errors()[0]["loc"] == ("tuple_value",) + + +def test_lt(): + class Model(SQLModel): + int_value: int = Field(lt=10) + tuple_value: tuple[int, int] = Field(lt=(1, 2)) + + Model(int_value=9, tuple_value=(1, 1)) + + with pytest.raises(ValidationError) as exc_info: + Model(int_value=10, tuple_value=(1, 1)) + assert len(exc_info.value.errors()) == 1 + assert exc_info.value.errors()[0]["type"] == "less_than" + assert exc_info.value.errors()[0]["loc"] == ("int_value",) + + with pytest.raises(ValidationError) as exc_info_2: + Model(int_value=9, tuple_value=(1, 2)) + assert len(exc_info_2.value.errors()) == 1 + assert exc_info_2.value.errors()[0]["type"] == "less_than" + assert exc_info_2.value.errors()[0]["loc"] == ("tuple_value",) + + +def test_le(): + class Model(SQLModel): + int_value: int = Field(le=10) + tuple_value: tuple[int, int] = Field(le=(1, 2)) + + Model(int_value=10, tuple_value=(1, 2)) + + with pytest.raises(ValidationError) as exc_info: + Model(int_value=11, tuple_value=(1, 2)) + assert len(exc_info.value.errors()) == 1 + assert exc_info.value.errors()[0]["type"] == "less_than_equal" + assert exc_info.value.errors()[0]["loc"] == ("int_value",) + + with pytest.raises(ValidationError) as exc_info_2: + Model(int_value=10, tuple_value=(1, 3)) + assert len(exc_info_2.value.errors()) == 1 + assert exc_info_2.value.errors()[0]["type"] == "less_than_equal" + assert exc_info_2.value.errors()[0]["loc"] == ("tuple_value",)