@@ -1488,3 +1488,183 @@ def fake_audit(snapshot: Snapshot, **kwargs: t.Any) -> t.List[AuditResult]:
14881488 assert all (tid != main_thread_id for tid in audit_call_thread_ids ), (
14891489 "Both audits should run on worker threads regardless of DAG dependencies"
14901490 )
1491+
1492+
1493+ @pytest .mark .fast
1494+ def test_audit_only_circuit_breaker_stops_remaining_tasks (mocker : MockerFixture , make_snapshot ):
1495+ """When the circuit breaker fires, remaining audit tasks are skipped and CircuitBreakerError is raised."""
1496+ audit_calls : t .List [str ] = []
1497+ audit_lock = threading .Lock ()
1498+
1499+ snapshot_a = make_snapshot (SqlModel (name = "a" , query = parse_one ("SELECT 1 as id" )))
1500+ snapshot_b = make_snapshot (SqlModel (name = "b" , query = parse_one ("SELECT 2 as id" )))
1501+ snapshot_c = make_snapshot (SqlModel (name = "c" , query = parse_one ("SELECT 3 as id" )))
1502+ snapshot_a .categorize_as (SnapshotChangeCategory .BREAKING )
1503+ snapshot_b .categorize_as (SnapshotChangeCategory .BREAKING )
1504+ snapshot_c .categorize_as (SnapshotChangeCategory .BREAKING )
1505+
1506+ def fake_audit (snapshot : Snapshot , ** kwargs : t .Any ) -> t .List [AuditResult ]:
1507+ with audit_lock :
1508+ audit_calls .append (snapshot .name )
1509+ return []
1510+
1511+ mock_evaluator = mocker .MagicMock ()
1512+ mock_evaluator .audit .side_effect = fake_audit
1513+ mock_evaluator .get_snapshots_to_create .return_value = []
1514+ mock_evaluator .concurrent_context .return_value .__enter__ = mocker .Mock (return_value = None )
1515+ mock_evaluator .concurrent_context .return_value .__exit__ = mocker .Mock (return_value = False )
1516+
1517+ # Circuit breaker fires immediately on the first check
1518+ scheduler = Scheduler (
1519+ snapshots = [snapshot_a , snapshot_b , snapshot_c ],
1520+ snapshot_evaluator = mock_evaluator ,
1521+ state_sync = mocker .MagicMock (),
1522+ default_catalog = None ,
1523+ max_workers = 1 , # Sequential so we can reason about ordering
1524+ )
1525+
1526+ interval = (to_timestamp ("2023-01-01" ), to_timestamp ("2023-01-02" ))
1527+ merged_intervals : SnapshotToIntervals = {
1528+ snapshot_a : [interval ],
1529+ snapshot_b : [interval ],
1530+ snapshot_c : [interval ],
1531+ }
1532+
1533+ with pytest .raises (CircuitBreakerError ):
1534+ scheduler .run_merged_intervals (
1535+ merged_intervals = merged_intervals ,
1536+ deployability_index = DeployabilityIndex .all_deployable (),
1537+ environment_naming_info = EnvironmentNamingInfo (),
1538+ audit_only = True ,
1539+ circuit_breaker = lambda : True ,
1540+ )
1541+
1542+ # With circuit breaker always-true, no audits should run
1543+ assert len (audit_calls ) == 0
1544+
1545+
1546+ @pytest .mark .fast
1547+ def test_audit_only_blocking_audit_error_collected (mocker : MockerFixture , make_snapshot ):
1548+ """When a blocking audit fails (raises NodeAuditsErrors), the error is collected and other audits still run."""
1549+ audit_calls : t .List [str ] = []
1550+ audit_lock = threading .Lock ()
1551+
1552+ snapshot_a = make_snapshot (SqlModel (name = "a" , query = parse_one ("SELECT 1 as id" )))
1553+ snapshot_b = make_snapshot (SqlModel (name = "b" , query = parse_one ("SELECT 2 as id" )))
1554+ snapshot_a .categorize_as (SnapshotChangeCategory .BREAKING )
1555+ snapshot_b .categorize_as (SnapshotChangeCategory .BREAKING )
1556+
1557+ def fake_audit (snapshot : Snapshot , ** kwargs : t .Any ) -> t .List [AuditResult ]:
1558+ with audit_lock :
1559+ audit_calls .append (snapshot .name )
1560+ if snapshot .name == '"a"' :
1561+ from sqlmesh .utils .errors import AuditError
1562+ from sqlglot import exp
1563+
1564+ audit_error = AuditError (
1565+ audit_name = "not_null" ,
1566+ audit_args = {},
1567+ model = snapshot .model_or_none ,
1568+ count = 5 ,
1569+ query = exp .select ("1" ),
1570+ adapter_dialect = "duckdb" ,
1571+ )
1572+ raise NodeAuditsErrors ([audit_error ])
1573+ return []
1574+
1575+ mock_evaluator = mocker .MagicMock ()
1576+ mock_evaluator .audit .side_effect = fake_audit
1577+ mock_evaluator .get_snapshots_to_create .return_value = []
1578+ mock_evaluator .concurrent_context .return_value .__enter__ = mocker .Mock (return_value = None )
1579+ mock_evaluator .concurrent_context .return_value .__exit__ = mocker .Mock (return_value = False )
1580+
1581+ mock_console = mocker .MagicMock ()
1582+
1583+ scheduler = Scheduler (
1584+ snapshots = [snapshot_a , snapshot_b ],
1585+ snapshot_evaluator = mock_evaluator ,
1586+ state_sync = mocker .MagicMock (),
1587+ default_catalog = None ,
1588+ max_workers = 2 ,
1589+ console = mock_console ,
1590+ )
1591+
1592+ interval = (to_timestamp ("2023-01-01" ), to_timestamp ("2023-01-02" ))
1593+ merged_intervals : SnapshotToIntervals = {
1594+ snapshot_a : [interval ],
1595+ snapshot_b : [interval ],
1596+ }
1597+
1598+ errors , skipped = scheduler .run_merged_intervals (
1599+ merged_intervals = merged_intervals ,
1600+ deployability_index = DeployabilityIndex .all_deployable (),
1601+ environment_naming_info = EnvironmentNamingInfo (),
1602+ audit_only = True ,
1603+ )
1604+
1605+ # The NodeAuditsErrors should be collected as an error, not re-raised
1606+ assert len (errors ) == 1
1607+ assert isinstance (errors [0 ].__cause__ , NodeAuditsErrors )
1608+ assert skipped == []
1609+ # Both audits should have been attempted despite one failing
1610+ assert len (audit_calls ) == 2
1611+ assert '"a"' in audit_calls
1612+ assert '"b"' in audit_calls
1613+
1614+
1615+ @pytest .mark .fast
1616+ def test_audit_only_no_nested_concurrency (mocker : MockerFixture , make_snapshot ):
1617+ """With scheduler max_workers > 1, each evaluator audit call uses sequential execution (audit_concurrent_tasks=1).
1618+
1619+ This prevents nested thread pool multiplication: max_workers * concurrent_tasks threads hitting
1620+ the DB at the same time.
1621+ """
1622+ import sqlmesh .core .snapshot .evaluator as evaluator_module
1623+
1624+ spy = mocker .spy (evaluator_module , "concurrent_apply_to_values" )
1625+
1626+ snapshot_a = make_snapshot (SqlModel (name = "a" , query = parse_one ("SELECT 1 as id" )))
1627+ snapshot_b = make_snapshot (SqlModel (name = "b" , query = parse_one ("SELECT 2 as id" )))
1628+ snapshot_a .categorize_as (SnapshotChangeCategory .BREAKING )
1629+ snapshot_b .categorize_as (SnapshotChangeCategory .BREAKING )
1630+
1631+ mock_evaluator = mocker .MagicMock ()
1632+ mock_evaluator .audit .return_value = []
1633+ mock_evaluator .get_snapshots_to_create .return_value = []
1634+ mock_evaluator .concurrent_context .return_value .__enter__ = mocker .Mock (return_value = None )
1635+ mock_evaluator .concurrent_context .return_value .__exit__ = mocker .Mock (return_value = False )
1636+
1637+ # Use the real SnapshotEvaluator to test the audit_concurrent_tasks parameter flows through
1638+ real_evaluator = SnapshotEvaluator (adapters = mocker .MagicMock (), concurrent_tasks = 4 )
1639+ real_evaluator .audit = mocker .MagicMock (return_value = []) # type: ignore
1640+
1641+ scheduler = Scheduler (
1642+ snapshots = [snapshot_a , snapshot_b ],
1643+ snapshot_evaluator = real_evaluator ,
1644+ state_sync = mocker .MagicMock (),
1645+ default_catalog = None ,
1646+ max_workers = 2 ,
1647+ )
1648+
1649+ interval = (to_timestamp ("2023-01-01" ), to_timestamp ("2023-01-02" ))
1650+ merged_intervals : SnapshotToIntervals = {
1651+ snapshot_a : [interval ],
1652+ snapshot_b : [interval ],
1653+ }
1654+
1655+ errors , skipped = scheduler .run_merged_intervals (
1656+ merged_intervals = merged_intervals ,
1657+ deployability_index = DeployabilityIndex .all_deployable (),
1658+ environment_naming_info = EnvironmentNamingInfo (),
1659+ audit_only = True ,
1660+ )
1661+
1662+ assert errors == []
1663+ assert skipped == []
1664+ assert real_evaluator .audit .call_count == 2
1665+
1666+ # Verify that audit_concurrent_tasks=1 was passed to each audit call to prevent nested pools
1667+ for call in real_evaluator .audit .call_args_list :
1668+ assert call .kwargs .get ("audit_concurrent_tasks" ) == 1 , (
1669+ "audit_concurrent_tasks=1 must be passed to prevent nested thread pool multiplication"
1670+ )
0 commit comments