From e36e2ea5023266f424e0484046d712fef19307e0 Mon Sep 17 00:00:00 2001 From: samini Date: Wed, 11 Mar 2026 17:27:30 +0000 Subject: [PATCH 1/4] Add generics tests --- tests/test_generics.py | 118 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 tests/test_generics.py diff --git a/tests/test_generics.py b/tests/test_generics.py new file mode 100644 index 0000000000..2950377f6e --- /dev/null +++ b/tests/test_generics.py @@ -0,0 +1,118 @@ +from enum import Enum +from typing import Generic, Literal, TypeVar + +import pytest +from sqlalchemy import create_engine +from sqlmodel import Field, Session, SQLModel, select +from typing_extensions import assert_type + + +def test_generic_type_with_bound(clear_sqlmodel) -> None: + TagT = TypeVar("TagT", bound=int) + + class HeroFields(SQLModel, Generic[TagT]): + tag: TagT + + class Hero(HeroFields[int], table=True): + id: int | None = Field(default=None, primary_key=True) + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + tag_number = 67 + hero = Hero(tag=tag_number) + session.add(hero) + + hero = session.exec(select(Hero).where(Hero.tag == tag_number)).first() + assert hero is not None + assert hero.tag == tag_number + + +def test_generic_type_with_constraints(clear_sqlmodel) -> None: + TagT = TypeVar("TagT", int, None) + + class HeroFields(SQLModel, Generic[TagT]): + tag: TagT + + class Hero(HeroFields[int], table=True): + id: int | None = Field(default=None, primary_key=True) + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + tag_number = 67 + hero = Hero(tag=tag_number) + session.add(hero) + + hero = session.exec(select(Hero).where(Hero.tag == tag_number)).first() + assert hero is not None + assert hero.tag == tag_number + + +def test_generic_type_with_multiple_type_constraints_raises_error( + clear_sqlmodel, +) -> None: + with pytest.raises(ValueError): + TagT = TypeVar("TagT", int, str) + + class HeroFields(SQLModel, Generic[TagT]): + tag: TagT + + class Hero(HeroFields[int], table=True): + id: int | None = Field(default=None, primary_key=True) + + +def test_discriminated_union_with_generics(clear_sqlmodel) -> None: + AmountRefundedT = TypeVar("AmountRefundedT", bound=int | None) + RejectionMessageT = TypeVar("RejectionMessageT", bound=str | None) + + class RefundStatus(str, Enum): + ACCEPTED = "ACCEPTED" + REJECTED = "REJECTED" + + DiscriminantT = TypeVar("DiscriminantT", bound=RefundStatus) + + class RefundRequestFields(SQLModel, Generic[AmountRefundedT, RejectionMessageT, DiscriminantT]): + item_name: str + amount_refunded: AmountRefundedT + rejection_message: RejectionMessageT + status: DiscriminantT + + class RefundRequest(RefundRequestFields[int | None, str | None, RefundStatus], table=True): + id: int | None = Field(default=None, primary_key=True) + status: RefundStatus + + class AcceptedRequest(RefundRequestFields[int, None, RefundStatus.ACCEPTED]): + amount_refunded: int + rejection_message: None = None + status: Literal[RefundStatus.ACCEPTED] = RefundStatus.ACCEPTED + + class RejectedRequest(RefundRequestFields[None, str, RefundStatus.REJECTED]): + rejection_message: str + amount_refunded: None = None + status: Literal[RefundStatus.REJECTED] = RefundStatus.REJECTED + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + c = RejectedRequest( + item_name="EmptyJuice", + rejection_message="This item cannot be refunded because it has been emptied", + ) + session.add(RefundRequest.model_validate(c.model_dump())) + + requests = session.exec( + select(RefundRequest).where( + RefundRequest.status == RefundStatus.REJECTED, + ) + ).all() + rejected_requests = [ + RejectedRequest.model_validate(request.model_dump()) + for request in requests + if request.status == RefundStatus.REJECTED + ] + assert_type(rejected_requests, list[RejectedRequest]) + assert len(rejected_requests) == 1 From 42e6e10f65980bdc2f5b5cf60642cb3d9e9ecea1 Mon Sep 17 00:00:00 2001 From: samini Date: Wed, 11 Mar 2026 17:27:36 +0000 Subject: [PATCH 2/4] Add generics support --- sqlmodel/main.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 300031de8b..c455155ae4 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -738,7 +738,24 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") +def _create_union(args: tuple[Any, ...]) -> Any | Any: + if len(args) == 1: + return args[0] + return args[0] | _create_union(args[1:]) + + def get_column_from_field(field: Any) -> Column: # type: ignore + if isinstance(field.annotation, TypeVar): + generic: TypeVar = field.annotation + if generic.__bound__ is not None: + field.annotation = generic.__bound__ + elif generic.__constraints__ != (): + constraints = generic.__constraints__ + field.annotation = _create_union(constraints) + else: + raise TypeError( + f"Invalid type used for {field}. Please define a bound or constraints." + ) field_info = field sa_column = _get_sqlmodel_field_value(field_info, "sa_column", Undefined) if isinstance(sa_column, Column): From df6d70a3eac827fc7d63ae67fc33b17c1947f166 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Wed, 11 Mar 2026 17:35:00 +0000 Subject: [PATCH 3/4] =?UTF-8?q?=F0=9F=8E=A8=20Auto=20format?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_generics.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_generics.py b/tests/test_generics.py index 2950377f6e..9ea22dc322 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -74,13 +74,17 @@ class RefundStatus(str, Enum): DiscriminantT = TypeVar("DiscriminantT", bound=RefundStatus) - class RefundRequestFields(SQLModel, Generic[AmountRefundedT, RejectionMessageT, DiscriminantT]): + class RefundRequestFields( + SQLModel, Generic[AmountRefundedT, RejectionMessageT, DiscriminantT] + ): item_name: str amount_refunded: AmountRefundedT rejection_message: RejectionMessageT status: DiscriminantT - class RefundRequest(RefundRequestFields[int | None, str | None, RefundStatus], table=True): + class RefundRequest( + RefundRequestFields[int | None, str | None, RefundStatus], table=True + ): id: int | None = Field(default=None, primary_key=True) status: RefundStatus From 59a3d3fb9408d8edc7ae29ba1128be6814ef81a0 Mon Sep 17 00:00:00 2001 From: samini Date: Wed, 11 Mar 2026 19:15:26 +0000 Subject: [PATCH 4/4] Fix return type --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index c455155ae4..c26a397eb9 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -738,7 +738,7 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def _create_union(args: tuple[Any, ...]) -> Any | Any: +def _create_union(args: tuple[Any, ...]) -> Any: if len(args) == 1: return args[0] return args[0] | _create_union(args[1:])