Skip to content

Commit 0621ae2

Browse files
committed
Add fail_fast param to Field, add tests
1 parent b612f28 commit 0621ae2

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

sqlmodel/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def Field(
229229
min_length: Optional[int] = None,
230230
max_length: Optional[int] = None,
231231
union_mode: Optional[Literal["smart", "left_to_right"]] = None,
232+
fail_fast: Optional[bool] = None,
232233
allow_mutation: bool = True,
233234
regex: Optional[str] = None,
234235
discriminator: Optional[str] = None,
@@ -275,6 +276,7 @@ def Field(
275276
min_length: Optional[int] = None,
276277
max_length: Optional[int] = None,
277278
union_mode: Optional[Literal["smart", "left_to_right"]] = None,
279+
fail_fast: Optional[bool] = None,
278280
allow_mutation: bool = True,
279281
regex: Optional[str] = None,
280282
discriminator: Optional[str] = None,
@@ -330,6 +332,7 @@ def Field(
330332
min_length: Optional[int] = None,
331333
max_length: Optional[int] = None,
332334
union_mode: Optional[Literal["smart", "left_to_right"]] = None,
335+
fail_fast: Optional[bool] = None,
333336
allow_mutation: bool = True,
334337
regex: Optional[str] = None,
335338
discriminator: Optional[str] = None,
@@ -366,6 +369,7 @@ def Field(
366369
min_length: Optional[int] = None,
367370
max_length: Optional[int] = None,
368371
union_mode: Optional[Literal["smart", "left_to_right"]] = None,
372+
fail_fast: Optional[bool] = None,
369373
allow_mutation: bool = True,
370374
regex: Optional[str] = None,
371375
discriminator: Optional[str] = None,
@@ -389,6 +393,7 @@ def Field(
389393
"coerce_numbers_to_str",
390394
"validate_default",
391395
"union_mode",
396+
"fail_fast",
392397
):
393398
if param_name in current_schema_extra:
394399
msg = f"Pass `{param_name}` parameter directly to Field instead of passing it via `schema_extra`"
@@ -403,6 +408,7 @@ def Field(
403408
current_validate_default = validate_default or current_schema_extra.pop(
404409
"validate_default", None
405410
)
411+
current_fail_fast = fail_fast or current_schema_extra.pop("fail_fast", None)
406412
field_info_kwargs = {
407413
"alias": alias,
408414
"title": title,
@@ -424,6 +430,7 @@ def Field(
424430
"unique_items": unique_items,
425431
"min_length": min_length,
426432
"max_length": max_length,
433+
"fail_fast": current_fail_fast,
427434
"allow_mutation": allow_mutation,
428435
"regex": regex,
429436
"discriminator": discriminator,

tests/test_pydantic/test_field.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,51 @@ class Model(SQLModel):
195195

196196
c = Model.model_validate({"val": 123.1})
197197
assert isinstance(c.val, float)
198+
199+
200+
def test_fail_fast_true():
201+
class Model(SQLModel):
202+
val: list[int] = Field(fail_fast=True)
203+
204+
with pytest.raises(ValidationError) as exc_info:
205+
Model.model_validate({"val": [1.1, "not an int"]})
206+
207+
errors = exc_info.value.errors()
208+
assert len(errors) == 1
209+
assert errors[0]["type"] == "int_from_float"
210+
211+
212+
@pytest.mark.parametrize("fail_fast", [None, False])
213+
def test_fail_fast_false(fail_fast: Optional[bool]):
214+
class Model(SQLModel):
215+
val: list[int] = Field(fail_fast=fail_fast)
216+
217+
with pytest.raises(ValidationError) as exc_info:
218+
Model.model_validate({"val": [1.1, "not an int"]})
219+
220+
errors = exc_info.value.errors()
221+
assert len(errors) == 2
222+
error_types = {error["type"] for error in errors}
223+
224+
assert "int_from_float" in error_types
225+
assert "int_parsing" in error_types
226+
227+
228+
def test_fail_fast_via_schema_extra(): # Current workaround. Remove after some time
229+
with pytest.warns(
230+
UserWarning,
231+
match=(
232+
"Pass `fail_fast` parameter directly to Field instead of passing "
233+
"it via `schema_extra`"
234+
),
235+
):
236+
237+
class Model(SQLModel):
238+
val: list[int] = Field(schema_extra={"fail_fast": True})
239+
240+
with pytest.raises(ValidationError) as exc_info:
241+
Model.model_validate({"val": [1.1, "not an int"]})
242+
243+
errors = exc_info.value.errors()
244+
assert len(errors) == 1
245+
assert errors[0]["type"] == "int_from_float"

0 commit comments

Comments
 (0)