Skip to content
Open
57 changes: 39 additions & 18 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
ClassVar,
Literal,
TypeAlias,
TypedDict,
TypeVar,
Union,
cast,
get_origin,
overload,
)

from pydantic import BaseModel, EmailStr
from pydantic import BaseModel, EmailStr, create_model
from pydantic.fields import FieldInfo as PydanticFieldInfo
from sqlalchemy import (
Boolean,
Expand All @@ -49,7 +50,7 @@
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
from typing_extensions import deprecated
from typing_extensions import Unpack, deprecated

from ._compat import (
PYDANTIC_MINOR_VERSION,
Expand Down Expand Up @@ -801,6 +802,22 @@ def get_column_from_field(field: Any) -> Column:
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")


class _ModelDumpKwargs(TypedDict):
mode: Literal["json", "python"] | str
include: IncEx | None
exclude: IncEx | None
context: Any | None # v2.7
by_alias: bool | None
exclude_unset: bool
exclude_defaults: bool
exclude_none: bool
exclude_computed_fields: bool # v2.12
round_trip: bool
warnings: bool | Literal["none", "warn", "error"]
fallback: Callable[[Any], Any] | None # v2.11
serialize_as_any: bool # v2.7


class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
Expand Down Expand Up @@ -984,25 +1001,29 @@ def sqlmodel_update(
obj: builtins.dict[str, Any] | BaseModel,
*,
update: builtins.dict[str, Any] | None = None,
**model_dump_kwargs: Unpack[_ModelDumpKwargs],
) -> _TSQLModel:
use_update = (update or {}).copy()
if isinstance(obj, dict):
for key, value in {**obj, **use_update}.items():
if key in get_model_fields(self):
setattr(self, key, value)
elif isinstance(obj, BaseModel):
for key in get_model_fields(obj):
if key in use_update:
value = use_update.pop(key)
else:
value = getattr(obj, key)
setattr(self, key, value)
for remaining_key, value in use_update.items():
if remaining_key in get_model_fields(self):
setattr(self, remaining_key, value)
else:
if not (isinstance(obj, dict) or isinstance(obj, BaseModel)):
raise ValueError(
"Can't use sqlmodel_update() with something that "
f"is not a dict or SQLModel or Pydantic model: {obj}"
)
if isinstance(obj, BaseModel):
# Create a temp UpdateModel schema (removes extra serialization settings)
ObjClass = obj.__class__
fields_def = {
fname: finfo.annotation
for fname, finfo in ObjClass.model_fields.items()
}
UpdateModel = create_model(f"_{ObjClass.__name__}Update_", **fields_def)
# rebuild obj instance with model_construct
obj = UpdateModel.model_construct(
_fields_set=obj.model_fields_set, **obj.__dict__
)
# Now `obj.model_dump` works with **model_dump_kwargs
obj = obj.model_dump(**model_dump_kwargs)
Comment on lines +1011 to +1024
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will break use cases with models that have fields with exclude=True:

from sqlmodel import Field, SQLModel


class Item(SQLModel):
    id: str
    param: str = Field(exclude=True)


a = Item.model_validate({"id": "1", "param": "1"})
b = Item.model_validate({"id": "1", "param": "2"})


a.sqlmodel_update(b, exclude={"id"})
# a.sqlmodel_update(b)


assert a.param == "2"

.. and probably some other cases when model has settings that change the default serialization schema.

use_update = (update or {}).copy()
for key, value in {**obj, **use_update}.items():
if key in get_model_fields(self):
setattr(self, key, value)
return self
14 changes: 13 additions & 1 deletion tests/test_update.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from pytest import raises
from sqlmodel import Field, SQLModel


def test_sqlmodel_update():
class Organization(SQLModel, table=True):
id: int = Field(default=None, primary_key=True)
name: str
city: str
headquarters: str

class OrganizationUpdate(SQLModel):
name: str
name: str = Field(exclude=True)
city: str | None = None

org = Organization(name="Example Org", city="New York", headquarters="NYC HQ")
org_in = OrganizationUpdate(name="Updated org")
Expand All @@ -17,4 +20,13 @@ class OrganizationUpdate(SQLModel):
update={
"headquarters": "-", # This field is in Organization, but not in OrganizationUpdate
},
exclude_unset=True,
)
# fields that should stay the same
assert org.city == "New York"
# fields that should be updated
assert org.name == "Updated org"
assert org.headquarters == "-"
# test raise value error when passing in updates other than dict or BaseModel
with raises(ValueError):
org.sqlmodel_update(["Boston"])
Loading