diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 9a1a676775..33023eec74 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -15,6 +15,7 @@ ClassVar, Literal, TypeAlias, + TypedDict, TypeVar, Union, cast, @@ -22,7 +23,7 @@ overload, ) -from pydantic import BaseModel, EmailStr +from pydantic import BaseModel, EmailStr, create_model from pydantic.fields import FieldInfo as PydanticFieldInfo from sqlalchemy import ( Boolean, @@ -49,7 +50,7 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import deprecated +from typing_extensions import Unpack, deprecated from ._compat import ( PYDANTIC_MINOR_VERSION, @@ -801,6 +802,22 @@ def get_column_from_field(field: Any) -> Column: _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") +class _ModelDumpKwargs(TypedDict): + mode: Literal["json", "python"] | str + include: IncEx | None + exclude: IncEx | None + context: Any | None # v2.7 + by_alias: bool | None + exclude_unset: bool + exclude_defaults: bool + exclude_none: bool + exclude_computed_fields: bool # v2.12 + round_trip: bool + warnings: bool | Literal["none", "warn", "error"] + fallback: Callable[[Any], Any] | None # v2.11 + serialize_as_any: bool # v2.7 + + class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) @@ -984,25 +1001,29 @@ def sqlmodel_update( obj: builtins.dict[str, Any] | BaseModel, *, update: builtins.dict[str, Any] | None = None, + **model_dump_kwargs: Unpack[_ModelDumpKwargs], ) -> _TSQLModel: - use_update = (update or {}).copy() - if isinstance(obj, dict): - for key, value in {**obj, **use_update}.items(): - if key in get_model_fields(self): - setattr(self, key, value) - elif isinstance(obj, BaseModel): - for key in get_model_fields(obj): - if key in use_update: - value = use_update.pop(key) - else: - value = getattr(obj, key) - setattr(self, key, value) - for remaining_key, value in use_update.items(): - if remaining_key in get_model_fields(self): - setattr(self, remaining_key, value) - else: + if not (isinstance(obj, dict) or isinstance(obj, BaseModel)): raise ValueError( "Can't use sqlmodel_update() with something that " f"is not a dict or SQLModel or Pydantic model: {obj}" ) + if isinstance(obj, BaseModel): + # Create a temp UpdateModel schema (removes extra serialization settings) + ObjClass = obj.__class__ + fields_def = { + fname: finfo.annotation + for fname, finfo in ObjClass.model_fields.items() + } + UpdateModel = create_model(f"_{ObjClass.__name__}Update_", **fields_def) + # rebuild obj instance with model_construct + obj = UpdateModel.model_construct( + _fields_set=obj.model_fields_set, **obj.__dict__ + ) + # Now `obj.model_dump` works with **model_dump_kwargs + obj = obj.model_dump(**model_dump_kwargs) + use_update = (update or {}).copy() + for key, value in {**obj, **use_update}.items(): + if key in get_model_fields(self): + setattr(self, key, value) return self diff --git a/tests/test_update.py b/tests/test_update.py index de4bd6cdd2..0d6beffbb3 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -1,3 +1,4 @@ +from pytest import raises from sqlmodel import Field, SQLModel @@ -5,10 +6,12 @@ def test_sqlmodel_update(): class Organization(SQLModel, table=True): id: int = Field(default=None, primary_key=True) name: str + city: str headquarters: str class OrganizationUpdate(SQLModel): - name: str + name: str = Field(exclude=True) + city: str | None = None org = Organization(name="Example Org", city="New York", headquarters="NYC HQ") org_in = OrganizationUpdate(name="Updated org") @@ -17,4 +20,13 @@ class OrganizationUpdate(SQLModel): update={ "headquarters": "-", # This field is in Organization, but not in OrganizationUpdate }, + exclude_unset=True, ) + # fields that should stay the same + assert org.city == "New York" + # fields that should be updated + assert org.name == "Updated org" + assert org.headquarters == "-" + # test raise value error when passing in updates other than dict or BaseModel + with raises(ValueError): + org.sqlmodel_update(["Boston"])