Skip to content

Commit b612f28

Browse files
committed
Add union_mode param to Field, add tests
1 parent 25232d1 commit b612f28

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

sqlmodel/main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def Field(
228228
unique_items: Optional[bool] = None,
229229
min_length: Optional[int] = None,
230230
max_length: Optional[int] = None,
231+
union_mode: Optional[Literal["smart", "left_to_right"]] = None,
231232
allow_mutation: bool = True,
232233
regex: Optional[str] = None,
233234
discriminator: Optional[str] = None,
@@ -273,6 +274,7 @@ def Field(
273274
unique_items: Optional[bool] = None,
274275
min_length: Optional[int] = None,
275276
max_length: Optional[int] = None,
277+
union_mode: Optional[Literal["smart", "left_to_right"]] = None,
276278
allow_mutation: bool = True,
277279
regex: Optional[str] = None,
278280
discriminator: Optional[str] = None,
@@ -327,6 +329,7 @@ def Field(
327329
unique_items: Optional[bool] = None,
328330
min_length: Optional[int] = None,
329331
max_length: Optional[int] = None,
332+
union_mode: Optional[Literal["smart", "left_to_right"]] = None,
330333
allow_mutation: bool = True,
331334
regex: Optional[str] = None,
332335
discriminator: Optional[str] = None,
@@ -362,6 +365,7 @@ def Field(
362365
unique_items: Optional[bool] = None,
363366
min_length: Optional[int] = None,
364367
max_length: Optional[int] = None,
368+
union_mode: Optional[Literal["smart", "left_to_right"]] = None,
365369
allow_mutation: bool = True,
366370
regex: Optional[str] = None,
367371
discriminator: Optional[str] = None,
@@ -384,6 +388,7 @@ def Field(
384388
for param_name in (
385389
"coerce_numbers_to_str",
386390
"validate_default",
391+
"union_mode",
387392
):
388393
if param_name in current_schema_extra:
389394
msg = f"Pass `{param_name}` parameter directly to Field instead of passing it via `schema_extra`"
@@ -444,6 +449,10 @@ def Field(
444449
serialization_alias or schema_serialization_alias or alias
445450
)
446451

452+
current_union_mode = union_mode or current_schema_extra.pop("union_mode", None)
453+
if current_union_mode is not None:
454+
field_info_kwargs["union_mode"] = current_union_mode
455+
447456
field_info = FieldInfo(
448457
default,
449458
default_factory=default_factory,

tests/test_pydantic/test_field.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,54 @@ class Model(SQLModel):
144144
val: int = Field(default="123", schema_extra={"validate_default": True})
145145

146146
assert Model.model_validate({}).val == 123
147+
148+
149+
@pytest.mark.parametrize("union_mode", [None, "smart"])
150+
def test_union_mode_smart(union_mode: Optional[Literal["smart"]]):
151+
class Model(SQLModel):
152+
val: Union[float, int] = Field(union_mode=union_mode)
153+
154+
a = Model.model_validate({"val": 123})
155+
assert isinstance(a.val, int) # float is first, but int is more precise
156+
157+
b = Model.model_validate({"val": 123.0})
158+
assert isinstance(b.val, float)
159+
160+
c = Model.model_validate({"val": 123.1})
161+
assert isinstance(c.val, float)
162+
163+
164+
def test_union_mode_left_to_right():
165+
class Model(SQLModel):
166+
val: Union[float, int] = Field(union_mode="left_to_right")
167+
168+
a = Model.model_validate({"val": 123})
169+
assert isinstance(a.val, float)
170+
171+
b = Model.model_validate({"val": 123.0})
172+
assert isinstance(b.val, float)
173+
174+
c = Model.model_validate({"val": 123.1})
175+
assert isinstance(c.val, float)
176+
177+
178+
def test_union_mode_via_schema_extra(): # Current workaround. Remove after some time
179+
with pytest.warns(
180+
UserWarning,
181+
match=(
182+
"Pass `union_mode` parameter directly to Field instead of passing "
183+
"it via `schema_extra`"
184+
),
185+
):
186+
187+
class Model(SQLModel):
188+
val: Union[float, int] = Field(schema_extra={"union_mode": "smart"})
189+
190+
a = Model.model_validate({"val": 123})
191+
assert isinstance(a.val, int) # float is first, but int is more precise
192+
193+
b = Model.model_validate({"val": 123.0})
194+
assert isinstance(b.val, float)
195+
196+
c = Model.model_validate({"val": 123.1})
197+
assert isinstance(c.val, float)

0 commit comments

Comments
 (0)