Skip to content

Commit 0f081da

Browse files
airhorns authored and claude committed
Fix circuit breaker, nested concurrency, and add test coverage
- Circuit breaker: Use a shared threading.Event to cancel remaining audit tasks when the circuit breaker fires. Previously, CircuitBreakerError was collected like any other error and all tasks ran to completion.
- Nested concurrency: Pass audit_concurrent_tasks=1 from the scheduler's flat pool to the evaluator, preventing max_workers * concurrent_tasks threads from hitting the DB simultaneously. Add audit_concurrent_tasks parameter to SnapshotEvaluator.audit() for this override.
- Add tests for circuit breaker short-circuiting, blocking audit error collection (NodeAuditsErrors), and nested concurrency prevention.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0b56d4d commit 0f081da

File tree

3 files changed

+194
-6
lines changed

3 files changed

+194
-6
lines changed

sqlmesh/core/scheduler.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,14 +1014,14 @@ def _run_audits_concurrently(
10141014

10151015
errors: t.List[NodeExecutionFailedError[SchedulingUnit]] = []
10161016
errors_lock = threading.Lock()
1017+
cancelled = threading.Event()
10171018

10181019
def run_audit_task(node: EvaluateNode) -> None:
1019-
# The circuit breaker is checked at task start. Tasks already submitted to the
1020-
# thread pool will run to completion — unlike the DAG executor's level-by-level
1021-
# cancellation, this is acceptable for audit-only runs because audits are
1022-
# read-only and have no side effects.
1020+
if cancelled.is_set():
1021+
return
10231022
if circuit_breaker and circuit_breaker():
1024-
raise CircuitBreakerError()
1023+
cancelled.set()
1024+
return
10251025

10261026
snapshot = self.snapshots_by_name[node.snapshot_name]
10271027
node_start, node_end = node.interval
@@ -1035,6 +1035,7 @@ def _do_audit() -> t.List[AuditResult]:
10351035
start=node_start,
10361036
end=node_end,
10371037
execution_time=execution_time,
1038+
audit_concurrent_tasks=1,
10381039
)
10391040

10401041
self._run_node_with_progress(
@@ -1048,6 +1049,8 @@ def _do_audit() -> t.List[AuditResult]:
10481049
def run_audit_task_collecting_errors(node: EvaluateNode) -> None:
10491050
try:
10501051
run_audit_task(node)
1052+
except CircuitBreakerError:
1053+
cancelled.set()
10511054
except Exception as ex:
10521055
error: NodeExecutionFailedError[SchedulingUnit] = NodeExecutionFailedError(node)
10531056
error.__cause__ = ex
@@ -1056,6 +1059,9 @@ def run_audit_task_collecting_errors(node: EvaluateNode) -> None:
10561059

10571060
concurrent_apply_to_values(audit_tasks, run_audit_task_collecting_errors, self.max_workers)
10581061

1062+
if cancelled.is_set():
1063+
raise CircuitBreakerError()
1064+
10591065
return errors, []
10601066

10611067
def _check_ready_intervals(

sqlmesh/core/snapshot/evaluator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,7 @@ def audit(
555555
execution_time: t.Optional[TimeLike] = None,
556556
deployability_index: t.Optional[DeployabilityIndex] = None,
557557
wap_id: t.Optional[str] = None,
558+
audit_concurrent_tasks: t.Optional[int] = None,
558559
**kwargs: t.Any,
559560
) -> t.List[AuditResult]:
560561
"""Execute a snapshot's node's audit queries.
@@ -632,10 +633,11 @@ def _run_audit(
632633
**kwargs,
633634
)
634635

636+
tasks_num = audit_concurrent_tasks if audit_concurrent_tasks is not None else self.concurrent_tasks
635637
results = concurrent_apply_to_values(
636638
prepared_audits,
637639
_run_audit,
638-
self.concurrent_tasks,
640+
tasks_num,
639641
)
640642

641643
if wap_id is not None:

tests/core/test_scheduler.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,3 +1488,183 @@ def fake_audit(snapshot: Snapshot, **kwargs: t.Any) -> t.List[AuditResult]:
14881488
assert all(tid != main_thread_id for tid in audit_call_thread_ids), (
14891489
"Both audits should run on worker threads regardless of DAG dependencies"
14901490
)
1491+
1492+
1493+
@pytest.mark.fast
1494+
def test_audit_only_circuit_breaker_stops_remaining_tasks(mocker: MockerFixture, make_snapshot):
1495+
"""When the circuit breaker fires, remaining audit tasks are skipped and CircuitBreakerError is raised."""
1496+
audit_calls: t.List[str] = []
1497+
audit_lock = threading.Lock()
1498+
1499+
snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id")))
1500+
snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 2 as id")))
1501+
snapshot_c = make_snapshot(SqlModel(name="c", query=parse_one("SELECT 3 as id")))
1502+
snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING)
1503+
snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING)
1504+
snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING)
1505+
1506+
def fake_audit(snapshot: Snapshot, **kwargs: t.Any) -> t.List[AuditResult]:
1507+
with audit_lock:
1508+
audit_calls.append(snapshot.name)
1509+
return []
1510+
1511+
mock_evaluator = mocker.MagicMock()
1512+
mock_evaluator.audit.side_effect = fake_audit
1513+
mock_evaluator.get_snapshots_to_create.return_value = []
1514+
mock_evaluator.concurrent_context.return_value.__enter__ = mocker.Mock(return_value=None)
1515+
mock_evaluator.concurrent_context.return_value.__exit__ = mocker.Mock(return_value=False)
1516+
1517+
# Circuit breaker fires immediately on the first check
1518+
scheduler = Scheduler(
1519+
snapshots=[snapshot_a, snapshot_b, snapshot_c],
1520+
snapshot_evaluator=mock_evaluator,
1521+
state_sync=mocker.MagicMock(),
1522+
default_catalog=None,
1523+
max_workers=1, # Sequential so we can reason about ordering
1524+
)
1525+
1526+
interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))
1527+
merged_intervals: SnapshotToIntervals = {
1528+
snapshot_a: [interval],
1529+
snapshot_b: [interval],
1530+
snapshot_c: [interval],
1531+
}
1532+
1533+
with pytest.raises(CircuitBreakerError):
1534+
scheduler.run_merged_intervals(
1535+
merged_intervals=merged_intervals,
1536+
deployability_index=DeployabilityIndex.all_deployable(),
1537+
environment_naming_info=EnvironmentNamingInfo(),
1538+
audit_only=True,
1539+
circuit_breaker=lambda: True,
1540+
)
1541+
1542+
# With circuit breaker always-true, no audits should run
1543+
assert len(audit_calls) == 0
1544+
1545+
1546+
@pytest.mark.fast
1547+
def test_audit_only_blocking_audit_error_collected(mocker: MockerFixture, make_snapshot):
1548+
"""When a blocking audit fails (raises NodeAuditsErrors), the error is collected and other audits still run."""
1549+
audit_calls: t.List[str] = []
1550+
audit_lock = threading.Lock()
1551+
1552+
snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id")))
1553+
snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 2 as id")))
1554+
snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING)
1555+
snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING)
1556+
1557+
def fake_audit(snapshot: Snapshot, **kwargs: t.Any) -> t.List[AuditResult]:
1558+
with audit_lock:
1559+
audit_calls.append(snapshot.name)
1560+
if snapshot.name == '"a"':
1561+
from sqlmesh.utils.errors import AuditError
1562+
from sqlglot import exp
1563+
1564+
audit_error = AuditError(
1565+
audit_name="not_null",
1566+
audit_args={},
1567+
model=snapshot.model_or_none,
1568+
count=5,
1569+
query=exp.select("1"),
1570+
adapter_dialect="duckdb",
1571+
)
1572+
raise NodeAuditsErrors([audit_error])
1573+
return []
1574+
1575+
mock_evaluator = mocker.MagicMock()
1576+
mock_evaluator.audit.side_effect = fake_audit
1577+
mock_evaluator.get_snapshots_to_create.return_value = []
1578+
mock_evaluator.concurrent_context.return_value.__enter__ = mocker.Mock(return_value=None)
1579+
mock_evaluator.concurrent_context.return_value.__exit__ = mocker.Mock(return_value=False)
1580+
1581+
mock_console = mocker.MagicMock()
1582+
1583+
scheduler = Scheduler(
1584+
snapshots=[snapshot_a, snapshot_b],
1585+
snapshot_evaluator=mock_evaluator,
1586+
state_sync=mocker.MagicMock(),
1587+
default_catalog=None,
1588+
max_workers=2,
1589+
console=mock_console,
1590+
)
1591+
1592+
interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))
1593+
merged_intervals: SnapshotToIntervals = {
1594+
snapshot_a: [interval],
1595+
snapshot_b: [interval],
1596+
}
1597+
1598+
errors, skipped = scheduler.run_merged_intervals(
1599+
merged_intervals=merged_intervals,
1600+
deployability_index=DeployabilityIndex.all_deployable(),
1601+
environment_naming_info=EnvironmentNamingInfo(),
1602+
audit_only=True,
1603+
)
1604+
1605+
# The NodeAuditsErrors should be collected as an error, not re-raised
1606+
assert len(errors) == 1
1607+
assert isinstance(errors[0].__cause__, NodeAuditsErrors)
1608+
assert skipped == []
1609+
# Both audits should have been attempted despite one failing
1610+
assert len(audit_calls) == 2
1611+
assert '"a"' in audit_calls
1612+
assert '"b"' in audit_calls
1613+
1614+
1615+
@pytest.mark.fast
1616+
def test_audit_only_no_nested_concurrency(mocker: MockerFixture, make_snapshot):
1617+
"""With scheduler max_workers > 1, each evaluator audit call uses sequential execution (audit_concurrent_tasks=1).
1618+
1619+
This prevents nested thread pool multiplication: max_workers * concurrent_tasks threads hitting
1620+
the DB at the same time.
1621+
"""
1622+
import sqlmesh.core.snapshot.evaluator as evaluator_module
1623+
1624+
spy = mocker.spy(evaluator_module, "concurrent_apply_to_values")
1625+
1626+
snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id")))
1627+
snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 2 as id")))
1628+
snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING)
1629+
snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING)
1630+
1631+
mock_evaluator = mocker.MagicMock()
1632+
mock_evaluator.audit.return_value = []
1633+
mock_evaluator.get_snapshots_to_create.return_value = []
1634+
mock_evaluator.concurrent_context.return_value.__enter__ = mocker.Mock(return_value=None)
1635+
mock_evaluator.concurrent_context.return_value.__exit__ = mocker.Mock(return_value=False)
1636+
1637+
# Use the real SnapshotEvaluator to test the audit_concurrent_tasks parameter flows through
1638+
real_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), concurrent_tasks=4)
1639+
real_evaluator.audit = mocker.MagicMock(return_value=[]) # type: ignore
1640+
1641+
scheduler = Scheduler(
1642+
snapshots=[snapshot_a, snapshot_b],
1643+
snapshot_evaluator=real_evaluator,
1644+
state_sync=mocker.MagicMock(),
1645+
default_catalog=None,
1646+
max_workers=2,
1647+
)
1648+
1649+
interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))
1650+
merged_intervals: SnapshotToIntervals = {
1651+
snapshot_a: [interval],
1652+
snapshot_b: [interval],
1653+
}
1654+
1655+
errors, skipped = scheduler.run_merged_intervals(
1656+
merged_intervals=merged_intervals,
1657+
deployability_index=DeployabilityIndex.all_deployable(),
1658+
environment_naming_info=EnvironmentNamingInfo(),
1659+
audit_only=True,
1660+
)
1661+
1662+
assert errors == []
1663+
assert skipped == []
1664+
assert real_evaluator.audit.call_count == 2
1665+
1666+
# Verify that audit_concurrent_tasks=1 was passed to each audit call to prevent nested pools
1667+
for call in real_evaluator.audit.call_args_list:
1668+
assert call.kwargs.get("audit_concurrent_tasks") == 1, (
1669+
"audit_concurrent_tasks=1 must be passed to prevent nested thread pool multiplication"
1670+
)

0 commit comments

Comments (0)