diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index 003c5bdad..a7c50ea44 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -28,12 +28,12 @@ from data_designer.config.utils.type_helpers import StrEnum from data_designer.config.utils.warning_helpers import warn_at_caller from data_designer.config.version import get_library_version +from data_designer.engine import flags from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, ColumnGeneratorWithModel, GenerationStrategy, ) -from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated from data_designer.engine.compiler import compile_data_designer_config from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig @@ -55,6 +55,7 @@ from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler from data_designer.engine.processing.processors.base import Processor from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor +from data_designer.engine.readiness import run_readiness_check from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry from data_designer.engine.resources.resource_provider import ResourceProvider from data_designer.engine.storage.artifact_storage import ( @@ -75,12 +76,12 @@ logger = logging.getLogger(__name__) -# Async engine is the default execution path. 
Set ``DATA_DESIGNER_ASYNC_ENGINE=0`` -# to opt back into the legacy sync engine for one transitional release; the sync -# path is scheduled for removal afterwards. -DATA_DESIGNER_ASYNC_ENGINE = os.environ.get("DATA_DESIGNER_ASYNC_ENGINE", "1") == "1" +# The async-engine flag now lives in ``data_designer.engine.flags`` so the +# engine, the public interface, and the readiness module can share one source +# of truth. Always read ``flags.DATA_DESIGNER_ASYNC_ENGINE`` rather than caching +# a local copy so monkeypatches in tests are visible. -if DATA_DESIGNER_ASYNC_ENGINE: +if flags.DATA_DESIGNER_ASYNC_ENGINE: import asyncio from data_designer.engine.dataset_builders.async_scheduler import ( @@ -133,7 +134,7 @@ def __init__( self._task_traces: list[TaskTrace] = [] self._registry = registry or DataDesignerRegistry() self._graph: ExecutionGraph | None = None - self._use_async: bool = DATA_DESIGNER_ASYNC_ENGINE + self._use_async: bool = flags.DATA_DESIGNER_ASYNC_ENGINE # Structured signal: set by _build_async if the scheduler hit early shutdown. # Stays at defaults for sync-engine and successful async runs. Reset at # the start of each public run path so reused builder instances don't @@ -215,10 +216,6 @@ def single_column_configs(self) -> list[ColumnConfigT]: def single_column_config_by_name(self) -> dict[str, ColumnConfigT]: return {config.name: config for config in self.single_column_configs} - @functools.cached_property - def llm_generated_column_configs(self) -> list[ColumnConfigT]: - return [config for config in self.single_column_configs if column_type_is_model_generated(config.column_type)] - def build( self, *, @@ -255,8 +252,7 @@ def build( """ self._reset_run_state() - self._run_model_health_check_if_needed() - self._run_mcp_tool_check_if_needed() + run_readiness_check(self.single_column_configs, self._resource_provider) # For IF_POSSIBLE and ALWAYS: check config compatibility before touching the artifact # directory. 
_check_resume_config_compatibility() must NOT access base_dataset_path @@ -326,7 +322,7 @@ def build( "start a new generation run." ) - self._use_async = DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility() + self._use_async = flags.DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility() if self._use_async: self._build_async(generators, num_records, buffer_size, on_batch_complete, resume=resume) elif resume == ResumeMode.ALWAYS: @@ -589,8 +585,7 @@ def _build_with_resume( def build_preview(self, *, num_records: int) -> pd.DataFrame: self._reset_run_state() - self._run_model_health_check_if_needed() - self._run_mcp_tool_check_if_needed() + run_readiness_check(self.single_column_configs, self._resource_provider) # Set media storage to DATAFRAME mode for preview - base64 stored directly in DataFrame if self._has_image_columns(): @@ -599,7 +594,7 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame: generators, self._graph = self._initialize_generators_and_graph() start_time = time.perf_counter() - self._use_async = DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility() + self._use_async = flags.DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility() if self._use_async: dataset = self._build_async_preview(generators, num_records) else: @@ -1327,38 +1322,6 @@ def _merge_skipped_and_generated( batch.append(gen_result) return batch - def _run_model_health_check_if_needed(self) -> None: - model_aliases: set[str] = set() - for config in self.single_column_configs: - model_aliases.update(config.get_model_aliases()) - - if not model_aliases: - return - - if DATA_DESIGNER_ASYNC_ENGINE: - loop = ensure_async_engine_loop() - future = asyncio.run_coroutine_threadsafe( - self._resource_provider.model_registry.arun_health_check(list(model_aliases)), - loop, - ) - try: - future.result(timeout=180) - except TimeoutError: - future.cancel() - raise - else: - 
self._resource_provider.model_registry.run_health_check(list(model_aliases)) - - def _run_mcp_tool_check_if_needed(self) -> None: - tool_aliases = sorted( - {config.tool_alias for config in self.llm_generated_column_configs if getattr(config, "tool_alias", None)} - ) - if not tool_aliases: - return - if self._resource_provider.mcp_registry is None: - raise DatasetGenerationError(f"Tool alias(es) {tool_aliases!r} specified but no MCPRegistry configured.") - self._resource_provider.mcp_registry.run_health_check(tool_aliases) - def _setup_fan_out( self, generator: ColumnGeneratorWithModelRegistry, diff --git a/packages/data-designer-engine/src/data_designer/engine/flags.py b/packages/data-designer-engine/src/data_designer/engine/flags.py new file mode 100644 index 000000000..b1e9f9f00 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/flags.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Engine-wide feature flags read from environment variables. + +This module exists so the engine, the public interface, and the readiness +module can share a single source of truth for runtime mode flags without +forming an import cycle. Tests patch values here to flip behavior for a +single test scope. +""" + +from __future__ import annotations + +import os + +# Async engine is the default execution path. Set ``DATA_DESIGNER_ASYNC_ENGINE=0`` +# to opt back into the legacy sync engine for one transitional release; the sync +# path is scheduled for removal afterwards. 
+DATA_DESIGNER_ASYNC_ENGINE: bool = os.environ.get("DATA_DESIGNER_ASYNC_ENGINE", "1") == "1"
diff --git a/packages/data-designer-engine/src/data_designer/engine/readiness.py b/packages/data-designer-engine/src/data_designer/engine/readiness.py
new file mode 100644
index 000000000..2c5a0a50b
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/readiness.py
@@ -0,0 +1,142 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""External-readiness checks for a DataDesigner workload.
+
+A "readiness" check is a pre-flight probe of every external resource a
+configuration depends on: each referenced model alias is sent a tiny
+generation request, and every referenced MCP tool alias is contacted to
+confirm its server is reachable.
+
+This module hosts the shared logic invoked from two places:
+
+- ``DatasetBuilder.build`` / ``DatasetBuilder.build_preview`` — at the start
+  of a workload, to fail fast before any expensive work begins.
+- ``DataDesigner.check_models`` — exposed publicly so users can verify
+  external dependencies are responsive without triggering a workload.
+
+The two callers must use the same code path here so the standalone method
+cannot drift from the workload-startup gate.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+
+from data_designer.engine import flags
+from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated
+from data_designer.engine.dataset_builders.errors import DatasetGenerationError
+
+if TYPE_CHECKING:
+    from data_designer.config.column_types import ColumnConfigT
+    from data_designer.engine.resources.resource_provider import ResourceProvider
+
+logger = logging.getLogger(__name__)
+
+# Match the timeout the dataset builder's startup gate has always used.
+_MODEL_HEALTH_CHECK_TIMEOUT_SECONDS = 180
+
+
+def run_readiness_check(
+    column_configs: Sequence[ColumnConfigT],
+    resource_provider: ResourceProvider,
+) -> None:
+    """Probe every model and MCP tool referenced by ``column_configs``.
+
+    For each unique model alias collected from the column configs,
+    ``ModelRegistry.run_health_check`` (or ``arun_health_check`` on the async
+    engine) sends a tiny ``"Hello!"`` generation. Models whose ``ModelConfig``
+    has ``skip_health_check=True`` are skipped by the registry. After the
+    model pass, every unique MCP tool alias is probed via
+    ``MCPRegistry.run_health_check``.
+
+    Args:
+        column_configs: The column configs whose ``get_model_aliases()`` and
+            ``tool_alias`` fields determine which aliases are probed.
+        resource_provider: Provides access to the model registry and MCP
+            registry. ``mcp_registry`` may be ``None`` only if no tool
+            aliases are referenced.
+
+    Raises:
+        Typed model errors from ``data_designer.engine.models.errors`` for
+            any failing model probe.
+        DatasetGenerationError: If a tool alias is referenced but no MCP
+            registry is configured on the resource provider.
+        TimeoutError: If async health-check execution exceeds
+            ``_MODEL_HEALTH_CHECK_TIMEOUT_SECONDS``. Before Python 3.11 the
+            raised type is ``concurrent.futures.TimeoutError``, which is not
+            yet an alias of the builtin there.
+    """
+    _run_model_health_check(column_configs, resource_provider)
+    _run_mcp_tool_health_check(column_configs, resource_provider)
+
+
+def _run_model_health_check(
+    column_configs: Sequence[ColumnConfigT],
+    resource_provider: ResourceProvider,
+) -> None:
+    """Probe every model alias referenced by ``column_configs``.
+
+    Returns without probing when no config reports a model alias. Aliases
+    are passed to the registry in sorted order, mirroring the MCP tool
+    check below.
+    """
+    model_aliases: set[str] = set()
+    for config in column_configs:
+        model_aliases.update(config.get_model_aliases())
+
+    if not model_aliases:
+        return
+
+    if flags.DATA_DESIGNER_ASYNC_ENGINE:
+        # Defer the async-engine imports to here so users on the legacy sync
+        # engine never pay the import cost. Mirrors the gating in
+        # ``dataset_builders.dataset_builder``.
+        import asyncio
+        import concurrent.futures
+
+        from data_designer.engine.dataset_builders.utils.async_concurrency import ensure_async_engine_loop
+
+        loop = ensure_async_engine_loop()
+        future = asyncio.run_coroutine_threadsafe(
+            resource_provider.model_registry.arun_health_check(sorted(model_aliases)),
+            loop,
+        )
+        try:
+            future.result(timeout=_MODEL_HEALTH_CHECK_TIMEOUT_SECONDS)
+        except concurrent.futures.TimeoutError:
+            # ``Future.result`` raises ``concurrent.futures.TimeoutError``. That
+            # is the builtin ``TimeoutError`` only on Python 3.11+; naming the
+            # builtin here would skip ``cancel()`` on older interpreters and
+            # leave the probe scheduled on the engine loop.
+            future.cancel()
+            raise
+    else:
+        resource_provider.model_registry.run_health_check(sorted(model_aliases))
+
+
+def _run_mcp_tool_health_check(
+    column_configs: Sequence[ColumnConfigT],
+    resource_provider: ResourceProvider,
+) -> None:
+    """Probe every MCP tool alias set on a model-generated column config.
+
+    Raises:
+        DatasetGenerationError: If a tool alias is referenced but the
+            resource provider has no MCP registry.
+    """
+    # Tool aliases are only meaningful on model-generated column configs.
+    tool_aliases = sorted(
+        {
+            config.tool_alias
+            for config in column_configs
+            if column_type_is_model_generated(config.column_type) and getattr(config, "tool_alias", None)
+        }
+    )
+    if not tool_aliases:
+        return
+    if resource_provider.mcp_registry is None:
+        raise DatasetGenerationError(f"Tool alias(es) {tool_aliases!r} specified but no MCPRegistry configured.")
+    resource_provider.mcp_registry.run_health_check(tool_aliases)
diff --git a/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py b/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py
index 98802726a..683bb6dc2 100644
--- a/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py
+++ b/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py
@@ -3,7 +3,6 @@
 
 from __future__ import annotations
 
-import os
 from typing import TYPE_CHECKING
 
 from data_designer.config.base import ConfigBase
@@ -13,6 +12,7 @@
 from data_designer.config.run_config import RunConfig
 from data_designer.config.seed_source import SeedSource
 from data_designer.config.utils.type_helpers import StrEnum
+from data_designer.engine import flags
 from data_designer.engine.mcp.factory import create_mcp_registry
 from 
data_designer.engine.mcp.registry import MCPRegistry from data_designer.engine.model_provider import ( @@ -148,9 +148,7 @@ def create_resource_provider( # default for backward compatibility. if client_concurrency_mode is None: client_concurrency_mode = ( - ClientConcurrencyMode.ASYNC - if os.environ.get("DATA_DESIGNER_ASYNC_ENGINE", "1") == "1" - else ClientConcurrencyMode.SYNC + ClientConcurrencyMode.ASYNC if flags.DATA_DESIGNER_ASYNC_ENGINE else ClientConcurrencyMode.SYNC ) effective_run_config = run_config or RunConfig() diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py index 6f2c74c49..7aff4cd23 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -233,10 +233,10 @@ def __init__(self, **kwargs: object) -> None: def test_sync_path_unaffected_by_async_engine_flag() -> None: """DATA_DESIGNER_ASYNC_ENGINE=0 keeps the sync path unchanged.""" - import data_designer.engine.dataset_builders.dataset_builder as builder_mod + from data_designer.engine import flags - assert hasattr(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE") - assert isinstance(builder_mod.DATA_DESIGNER_ASYNC_ENGINE, bool) + assert hasattr(flags, "DATA_DESIGNER_ASYNC_ENGINE") + assert isinstance(flags.DATA_DESIGNER_ASYNC_ENGINE, bool) # -- Test execution graph integration with real column configs ----------------- diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py index bc6328a96..2c4e0ca0b 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py 
@@ -20,6 +20,7 @@ from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams from data_designer.config.seed_source import LocalFileSeedSource from data_designer.config.seed_source_dataframe import DataFrameSeedSource +from data_designer.engine import flags from data_designer.engine.column_generators.generators.base import GenerationStrategy from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder, _ConfigCompatibility from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError @@ -48,7 +49,7 @@ def _force_sync_engine(monkeypatch: pytest.MonkeyPatch) -> None: behavior; the async path has dedicated coverage in ``test_async_builder_integration.py`` and ``test_async_scheduler.py``. """ - monkeypatch.setattr(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", False) + monkeypatch.setattr(flags, "DATA_DESIGNER_ASYNC_ENGINE", False) @pytest.fixture @@ -324,61 +325,6 @@ def test_dataset_builder_build_method_basic_flow( assert result_path == stub_resource_provider.artifact_storage.final_dataset_path -def test_run_model_health_check_collects_aliases_from_get_model_aliases( - stub_resource_provider, - stub_model_configs, -) -> None: - """The health check pings every alias returned by each config's get_model_aliases(). - - Regression test for #606: secondary aliases on multi-model plugin configs (returned via - get_model_aliases()) must be passed to run_health_check(), not just the primary - model_alias field. 
- """ - stub_resource_provider.model_registry.run_health_check = Mock() - - @custom_column_generator(model_aliases=["custom-model-a", "custom-model-b"]) - def gen_with_two_models(row: dict, generator_params, models) -> dict: - del generator_params, models - return row - - config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) - config_builder.add_column( - SamplerColumnConfig(name="seed_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams()) - ) - config_builder.add_column(LLMTextColumnConfig(name="builtin_llm_col", prompt="x", model_alias="builtin-model")) - config_builder.add_column(CustomColumnConfig(name="custom_col", generator_function=gen_with_two_models)) - - builder = DatasetBuilder( - data_designer_config=config_builder.build(), - resource_provider=stub_resource_provider, - ) - builder._run_model_health_check_if_needed() - - stub_resource_provider.model_registry.run_health_check.assert_called_once() - (called_aliases,), _ = stub_resource_provider.model_registry.run_health_check.call_args - assert set(called_aliases) == {"builtin-model", "custom-model-a", "custom-model-b"} - - -def test_run_model_health_check_skips_when_no_model_aliases( - stub_resource_provider, - stub_model_configs, -) -> None: - """Configs with no model aliases (e.g. 
samplers only) skip the health check entirely.""" - stub_resource_provider.model_registry.run_health_check = Mock() - - config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) - config_builder.add_column( - SamplerColumnConfig(name="seed_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams()) - ) - builder = DatasetBuilder( - data_designer_config=config_builder.build(), - resource_provider=stub_resource_provider, - ) - builder._run_model_health_check_if_needed() - - stub_resource_provider.model_registry.run_health_check.assert_not_called() - - @pytest.mark.parametrize( "column_configs,expected_error", [ @@ -1609,7 +1555,7 @@ def test_build_resume_starts_fresh_without_metadata(stub_resource_provider, stub builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path) with caplog.at_level(logging.INFO): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder_mod, "run_readiness_check"): with patch.object(builder, "_run_batch"): with patch.object(builder.batch_manager, "finish"): # resume=False is set internally; build dispatches to the normal (non-resume) path @@ -1667,7 +1613,7 @@ def test_build_resume_allows_larger_num_records(stub_resource_provider, stub_tes builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) with caplog.at_level(logging.WARNING): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder_mod, "run_readiness_check"): # 6 > 4 already generated → not already complete, should start generating # Here we just verify it does NOT raise on the num_records check with patch.object(builder, "_build_with_resume", return_value=True): @@ -1741,7 +1687,7 @@ def test_build_if_possible_starts_fresh_on_dropped_column_artifact_policy_mismat resource_provider=stub_resource_provider, ) - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder_mod, 
"run_readiness_check"): with patch.object(builder, "_run_batch"): with patch.object(builder.batch_manager, "finish"): final_path = builder.build(num_records=4, resume=ResumeMode.IF_POSSIBLE) @@ -1995,7 +1941,7 @@ def test_build_resume_not_already_complete_when_extension_fits_in_slack( with patch.object(builder, "_run_batch") as mock_run_batch: with patch.object(builder.batch_manager, "finish"): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder_mod, "run_readiness_check"): builder.build(num_records=6, resume=ResumeMode.ALWAYS) mock_run_batch.assert_called_once() @@ -2024,7 +1970,7 @@ def test_build_resume_recovers_progress_from_disk_when_metadata_lags( with caplog.at_level(logging.WARNING): with patch.object(builder, "_run_batch") as mock_run_batch: with patch.object(builder.batch_manager, "finish"): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder_mod, "run_readiness_check"): builder.build(num_records=4, resume=ResumeMode.ALWAYS) mock_run_batch.assert_not_called() @@ -2068,8 +2014,8 @@ def test_build_async_resume_logs_warning_when_already_complete( builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) with caplog.at_level(logging.WARNING): - with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(flags, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder_mod, "run_readiness_check"): builder.build(num_records=4, resume=ResumeMode.ALWAYS) assert any("already complete" in record.message for record in caplog.records) @@ -2091,8 +2037,8 @@ def test_build_async_resume_starts_fresh_without_metadata( builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path) with caplog.at_level(logging.INFO): - with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): - with patch.object(builder, 
"_run_model_health_check_if_needed"): + with patch.object(flags, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder_mod, "run_readiness_check"): with patch.object(builder, "_build_async", return_value=True) as mock_async: builder.build(num_records=4, resume=ResumeMode.ALWAYS) @@ -2112,8 +2058,8 @@ def test_build_async_resume_already_complete_does_not_run_after_generation_proce builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) - with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(flags, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder_mod, "run_readiness_check"): with patch.object(builder._processor_runner, "run_after_generation") as mock_after: builder.build(num_records=4, resume=ResumeMode.ALWAYS) @@ -2138,8 +2084,8 @@ def test_find_completed_row_groups_used_for_initial_total_batches( builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) # Both row groups are on disk → dataset is already complete → generated=False - with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(flags, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder_mod, "run_readiness_check"): with patch.object(builder._processor_runner, "run_after_generation") as mock_after: builder.build(num_records=4, resume=ResumeMode.ALWAYS) @@ -2186,11 +2132,11 @@ def capturing_prepare(*args, **kwargs): # asyncio and ensure_async_engine_loop are lazy-imported in dataset_builder only when # DATA_DESIGNER_ASYNC_ENGINE=True at module load time. Inject them for the duration # of this test so _build_async can proceed past the early-return path. 
- with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(flags, "DATA_DESIGNER_ASYNC_ENGINE", True): with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder_mod, "run_readiness_check"): with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): builder.build(num_records=6, resume=ResumeMode.ALWAYS) @@ -2228,11 +2174,11 @@ def capturing_prepare(*args, **kwargs): mock_future = Mock() mock_future.result = Mock(return_value=None) - with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(flags, "DATA_DESIGNER_ASYNC_ENGINE", True): with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder_mod, "run_readiness_check"): with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): builder.build(num_records=6, resume=ResumeMode.ALWAYS) @@ -2271,11 +2217,11 @@ def capturing_prepare(*args, **kwargs): mock_future = Mock() mock_future.result = Mock(return_value=None) - with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(flags, "DATA_DESIGNER_ASYNC_ENGINE", True): with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): - with patch.object(builder, 
"_run_model_health_check_if_needed"): + with patch.object(builder_mod, "run_readiness_check"): with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): # Extend the dataset: new target is 7, original was 5 builder.build(num_records=7, resume=ResumeMode.ALWAYS) @@ -2319,11 +2265,11 @@ def capturing_prepare(*args, **kwargs): mock_future = Mock() mock_future.result = Mock(return_value=None) - with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(flags, "DATA_DESIGNER_ASYNC_ENGINE", True): with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder_mod, "run_readiness_check"): with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): builder.build(num_records=9, resume=ResumeMode.ALWAYS) @@ -2374,11 +2320,11 @@ def capturing_prepare(*args, **kwargs): mock_future = Mock() mock_future.result = Mock(return_value=None) - with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(flags, "DATA_DESIGNER_ASYNC_ENGINE", True): with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder_mod, "run_readiness_check"): with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): builder.build(num_records=9, resume=ResumeMode.ALWAYS) @@ -2417,11 +2363,11 @@ def capturing_prepare(*args, **kwargs): mock_future = Mock() mock_future.result = Mock(return_value=None) - with 
patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(flags, "DATA_DESIGNER_ASYNC_ENGINE", True): with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder_mod, "run_readiness_check"): with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): builder.build(num_records=6, resume=ResumeMode.ALWAYS) @@ -2462,11 +2408,11 @@ def capturing_prepare(*args, **kwargs): mock_future = Mock() mock_future.result = Mock(return_value=None) - with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(flags, "DATA_DESIGNER_ASYNC_ENGINE", True): with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): - with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder_mod, "run_readiness_check"): with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): builder.build(num_records=7, resume=ResumeMode.ALWAYS) @@ -2501,11 +2447,11 @@ def capturing_prepare(*args, **kwargs): mock_future = Mock() mock_future.result = Mock(return_value=None) - with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(flags, "DATA_DESIGNER_ASYNC_ENGINE", True): with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): - with patch.object(builder, 
"_run_model_health_check_if_needed"): + with patch.object(builder_mod, "run_readiness_check"): with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare) as mock_prepare: builder.build(num_records=6, resume=ResumeMode.ALWAYS) @@ -2542,14 +2488,13 @@ def test_if_possible_incompatible_config_does_not_overwrite_existing_dataset( # Simulate incompatible config and mock out all I/O so build() does not actually generate data with patch.object(builder, "_check_resume_config_compatibility", return_value=_ConfigCompatibility.INCOMPATIBLE): - with patch.object(builder, "_run_model_health_check_if_needed"): - with patch.object(builder, "_run_mcp_tool_check_if_needed"): - with patch.object(builder, "_write_builder_config"): - with patch.object(builder, "_initialize_generators_and_graph", return_value=([], None)): - with patch.object(builder.batch_manager, "start"): - with patch.object(builder.batch_manager, "finish"): - with patch.object(builder._processor_runner, "run_after_generation"): - builder.build(num_records=2, resume=ResumeMode.IF_POSSIBLE) + with patch.object(builder_mod, "run_readiness_check"): + with patch.object(builder, "_write_builder_config"): + with patch.object(builder, "_initialize_generators_and_graph", return_value=([], None)): + with patch.object(builder.batch_manager, "start"): + with patch.object(builder.batch_manager, "finish"): + with patch.object(builder._processor_runner, "run_after_generation"): + builder.build(num_records=2, resume=ResumeMode.IF_POSSIBLE) # artifact_storage.resume must be downgraded to NEVER so resolved_dataset_name uses NEVER semantics assert storage.resume == ResumeMode.NEVER @@ -2592,14 +2537,13 @@ def test_if_possible_incompatible_config_refreshes_media_storage_path( ) with patch.object(builder, "_check_resume_config_compatibility", return_value=_ConfigCompatibility.INCOMPATIBLE): - with patch.object(builder, "_run_model_health_check_if_needed"): - with patch.object(builder, 
"_run_mcp_tool_check_if_needed"): - with patch.object(builder, "_write_builder_config"): - with patch.object(builder, "_initialize_generators_and_graph", return_value=([], None)): - with patch.object(builder.batch_manager, "start"): - with patch.object(builder.batch_manager, "finish"): - with patch.object(builder._processor_runner, "run_after_generation"): - builder.build(num_records=2, resume=ResumeMode.IF_POSSIBLE) + with patch.object(builder_mod, "run_readiness_check"): + with patch.object(builder, "_write_builder_config"): + with patch.object(builder, "_initialize_generators_and_graph", return_value=([], None)): + with patch.object(builder.batch_manager, "start"): + with patch.object(builder.batch_manager, "finish"): + with patch.object(builder._processor_runner, "run_after_generation"): + builder.build(num_records=2, resume=ResumeMode.IF_POSSIBLE) new_media_base = storage.media_storage.base_path assert new_media_base != original_media_base, ( @@ -2629,14 +2573,13 @@ def test_if_possible_starts_fresh_when_no_existing_directory( resource_provider=stub_resource_provider, ) - with patch.object(builder, "_run_model_health_check_if_needed"): - with patch.object(builder, "_run_mcp_tool_check_if_needed"): - with patch.object(builder, "_write_builder_config"): - with patch.object(builder, "_initialize_generators_and_graph", return_value=([], None)): - with patch.object(builder.batch_manager, "start"): - with patch.object(builder.batch_manager, "finish"): - with patch.object(builder._processor_runner, "run_after_generation"): - builder.build(num_records=2, resume=ResumeMode.IF_POSSIBLE) + with patch.object(builder_mod, "run_readiness_check"): + with patch.object(builder, "_write_builder_config"): + with patch.object(builder, "_initialize_generators_and_graph", return_value=([], None)): + with patch.object(builder.batch_manager, "start"): + with patch.object(builder.batch_manager, "finish"): + with patch.object(builder._processor_runner, "run_after_generation"): + 
builder.build(num_records=2, resume=ResumeMode.IF_POSSIBLE) assert storage.resume == ResumeMode.NEVER @@ -2662,13 +2605,12 @@ def test_if_possible_starts_fresh_when_directory_is_empty(stub_resource_provider resource_provider=stub_resource_provider, ) - with patch.object(builder, "_run_model_health_check_if_needed"): - with patch.object(builder, "_run_mcp_tool_check_if_needed"): - with patch.object(builder, "_write_builder_config"): - with patch.object(builder, "_initialize_generators_and_graph", return_value=([], None)): - with patch.object(builder.batch_manager, "start"): - with patch.object(builder.batch_manager, "finish"): - with patch.object(builder._processor_runner, "run_after_generation"): - builder.build(num_records=2, resume=ResumeMode.IF_POSSIBLE) + with patch.object(builder_mod, "run_readiness_check"): + with patch.object(builder, "_write_builder_config"): + with patch.object(builder, "_initialize_generators_and_graph", return_value=([], None)): + with patch.object(builder.batch_manager, "start"): + with patch.object(builder.batch_manager, "finish"): + with patch.object(builder._processor_runner, "run_after_generation"): + builder.build(num_records=2, resume=ResumeMode.IF_POSSIBLE) assert storage.resume == ResumeMode.NEVER diff --git a/packages/data-designer-engine/tests/engine/test_readiness.py b/packages/data-designer-engine/tests/engine/test_readiness.py new file mode 100644 index 000000000..13711d265 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/test_readiness.py @@ -0,0 +1,369 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Sequence +from unittest.mock import Mock, patch + +import pytest + +from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig +from data_designer.config.column_types import ColumnConfigT +from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.custom_column import custom_column_generator +from data_designer.config.models import ModelConfig +from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams +from data_designer.engine import flags +from data_designer.engine.dataset_builders.errors import DatasetGenerationError +from data_designer.engine.mcp.registry import MCPRegistry +from data_designer.engine.readiness import run_readiness_check + + +@pytest.fixture(autouse=True) +def _force_sync_engine(monkeypatch: pytest.MonkeyPatch) -> None: + """Pin readiness tests to the sync engine. + + Lets us assert against ``run_health_check`` directly without standing up + an event loop. + """ + monkeypatch.setattr(flags, "DATA_DESIGNER_ASYNC_ENGINE", False) + + +def _build_columns( + *, + model_configs: list[ModelConfig], + llm_columns: Sequence[tuple[str, str]] = (), + include_sampler: bool = True, +) -> list[ColumnConfigT]: + """Build a ``DataDesignerConfig`` and return its (already-flat) column configs. + + ``llm_columns`` is a list of ``(name, model_alias)`` pairs. ``include_sampler`` + adds a UUID sampler column at the start so configs that use a seed-only fast + path are still well-formed. 
+ """ + builder = DataDesignerConfigBuilder(model_configs=model_configs) + if include_sampler: + builder.add_column( + SamplerColumnConfig(name="seed_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams()) + ) + for name, model_alias in llm_columns: + builder.add_column(LLMTextColumnConfig(name=name, prompt="x", model_alias=model_alias)) + return builder.build().columns + + +# --------------------------------------------------------------------------- +# Model health check +# --------------------------------------------------------------------------- + + +def test_run_readiness_check_collects_aliases_from_get_model_aliases( + stub_resource_provider, + stub_model_configs, +) -> None: + """The model probe pings every alias returned by each config's ``get_model_aliases()``. + + Regression coverage for #606 — secondary aliases on multi-model plugin configs + (returned via ``get_model_aliases()``) must be passed to ``run_health_check()``, + not just the primary ``model_alias`` field. 
+ """ + stub_resource_provider.model_registry.run_health_check = Mock() + stub_resource_provider.mcp_registry = None + + @custom_column_generator(model_aliases=["custom-model-a", "custom-model-b"]) + def _gen_with_two_models(row, generator_params, models): + del generator_params, models + return row + + builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + builder.add_column(SamplerColumnConfig(name="seed_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams())) + builder.add_column(LLMTextColumnConfig(name="builtin", prompt="x", model_alias="builtin-model")) + from data_designer.config.column_configs import CustomColumnConfig + + builder.add_column(CustomColumnConfig(name="custom_col", generator_function=_gen_with_two_models)) + + run_readiness_check(builder.build().columns, stub_resource_provider) + + stub_resource_provider.model_registry.run_health_check.assert_called_once() + (called_aliases,), _ = stub_resource_provider.model_registry.run_health_check.call_args + assert set(called_aliases) == {"builtin-model", "custom-model-a", "custom-model-b"} + + +def test_run_readiness_check_skips_model_probe_when_no_aliases( + stub_resource_provider, + stub_model_configs, +) -> None: + """Configs with no model aliases (samplers only) skip the model health check entirely.""" + stub_resource_provider.model_registry.run_health_check = Mock() + stub_resource_provider.mcp_registry = None + + columns = _build_columns(model_configs=stub_model_configs, llm_columns=[]) + + run_readiness_check(columns, stub_resource_provider) + + stub_resource_provider.model_registry.run_health_check.assert_not_called() + + +def test_run_readiness_check_propagates_model_probe_error( + stub_resource_provider, + stub_model_configs, +) -> None: + """Exceptions from ``run_health_check`` bubble up unchanged for caller branching.""" + from data_designer.engine.models.errors import ModelAuthenticationError + + stub_resource_provider.model_registry.run_health_check = 
Mock(side_effect=ModelAuthenticationError("bad creds")) + stub_resource_provider.mcp_registry = None + + columns = _build_columns(model_configs=stub_model_configs, llm_columns=[("col", "stub-text")]) + + with pytest.raises(ModelAuthenticationError, match="bad creds"): + run_readiness_check(columns, stub_resource_provider) + + +# --------------------------------------------------------------------------- +# MCP tool health check +# --------------------------------------------------------------------------- + + +def test_run_readiness_check_collects_unique_sorted_tool_aliases( + stub_resource_provider, + stub_model_configs, +) -> None: + """Tool probes are called once per unique alias, sorted, after the model probe.""" + stub_resource_provider.model_registry.run_health_check = Mock() + mock_mcp_registry = Mock(spec=MCPRegistry) + stub_resource_provider.mcp_registry = mock_mcp_registry + + builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + builder.add_column(SamplerColumnConfig(name="seed_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams())) + builder.add_column(LLMTextColumnConfig(name="a", prompt="x", model_alias="stub-text", tool_alias="zebra")) + builder.add_column(LLMTextColumnConfig(name="b", prompt="x", model_alias="stub-text", tool_alias="alpha")) + builder.add_column( + LLMTextColumnConfig(name="c", prompt="x", model_alias="stub-text", tool_alias="alpha") # duplicate + ) + + run_readiness_check(builder.build().columns, stub_resource_provider) + + mock_mcp_registry.run_health_check.assert_called_once_with(["alpha", "zebra"]) + + +def test_run_readiness_check_skips_tool_probe_when_no_tool_aliases( + stub_resource_provider, + stub_model_configs, +) -> None: + """Configs with no tool aliases never touch the MCP registry.""" + stub_resource_provider.model_registry.run_health_check = Mock() + mock_mcp_registry = Mock(spec=MCPRegistry) + stub_resource_provider.mcp_registry = mock_mcp_registry + + columns = 
_build_columns(model_configs=stub_model_configs, llm_columns=[("col", "stub-text")]) + + run_readiness_check(columns, stub_resource_provider) + + mock_mcp_registry.run_health_check.assert_not_called() + + +def test_run_readiness_check_raises_when_tools_referenced_but_no_mcp_registry( + stub_resource_provider, + stub_model_configs, +) -> None: + """Tool aliases are referenced but ``mcp_registry`` is ``None`` — must fail loudly.""" + stub_resource_provider.model_registry.run_health_check = Mock() + stub_resource_provider.mcp_registry = None + + builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + builder.add_column(SamplerColumnConfig(name="seed_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams())) + builder.add_column(LLMTextColumnConfig(name="col", prompt="x", model_alias="stub-text", tool_alias="missing-tools")) + + with pytest.raises(DatasetGenerationError, match="missing-tools"): + run_readiness_check(builder.build().columns, stub_resource_provider) + + +def test_run_readiness_check_propagates_tool_probe_error( + stub_resource_provider, + stub_model_configs, +) -> None: + """Exceptions from MCP ``run_health_check`` bubble up unchanged.""" + stub_resource_provider.model_registry.run_health_check = Mock() + mock_mcp_registry = Mock(spec=MCPRegistry) + mock_mcp_registry.run_health_check = Mock(side_effect=RuntimeError("mcp down")) + stub_resource_provider.mcp_registry = mock_mcp_registry + + builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + builder.add_column(SamplerColumnConfig(name="seed_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams())) + builder.add_column(LLMTextColumnConfig(name="col", prompt="x", model_alias="stub-text", tool_alias="tools")) + + with pytest.raises(RuntimeError, match="mcp down"): + run_readiness_check(builder.build().columns, stub_resource_provider) + + +# --------------------------------------------------------------------------- +# Ordering +# 
--------------------------------------------------------------------------- + + +def test_run_readiness_check_runs_models_before_tools( + stub_resource_provider, + stub_model_configs, +) -> None: + """The model probe must run first; an MCP failure is irrelevant if models fail first.""" + from data_designer.engine.models.errors import ModelAuthenticationError + + stub_resource_provider.model_registry.run_health_check = Mock(side_effect=ModelAuthenticationError("bad creds")) + mock_mcp_registry = Mock(spec=MCPRegistry) + stub_resource_provider.mcp_registry = mock_mcp_registry + + builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + builder.add_column(SamplerColumnConfig(name="seed_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams())) + builder.add_column(LLMTextColumnConfig(name="col", prompt="x", model_alias="stub-text", tool_alias="tools")) + + with pytest.raises(ModelAuthenticationError): + run_readiness_check(builder.build().columns, stub_resource_provider) + + # The MCP probe must not have been reached. 
+ mock_mcp_registry.run_health_check.assert_not_called() + + +def test_run_readiness_check_no_models_no_tools_is_noop( + stub_resource_provider, + stub_model_configs, +) -> None: + """A pure-sampler config touches neither registry.""" + stub_resource_provider.model_registry.run_health_check = Mock() + mock_mcp_registry = Mock(spec=MCPRegistry) + stub_resource_provider.mcp_registry = mock_mcp_registry + + columns = _build_columns(model_configs=stub_model_configs, llm_columns=[]) + + run_readiness_check(columns, stub_resource_provider) + + stub_resource_provider.model_registry.run_health_check.assert_not_called() + mock_mcp_registry.run_health_check.assert_not_called() + + +# --------------------------------------------------------------------------- +# Column-type coverage +# --------------------------------------------------------------------------- + + +def test_run_readiness_check_collects_image_model_aliases( + stub_resource_provider, + stub_model_configs, +) -> None: + """Image-generation columns contribute their model aliases like LLM columns do. + + The dataset builder dispatches probes by ``model_generation_type`` inside + ``ModelRegistry.run_health_check``; readiness is generation-type-agnostic + and must surface every alias regardless of column kind. 
+ """ + from data_designer.config.column_configs import ImageColumnConfig + + stub_resource_provider.model_registry.run_health_check = Mock() + stub_resource_provider.mcp_registry = None + + builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + builder.add_column(SamplerColumnConfig(name="seed_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams())) + builder.add_column(LLMTextColumnConfig(name="caption", prompt="x", model_alias="stub-text")) + builder.add_column(ImageColumnConfig(name="picture", prompt="y", model_alias="stub-image")) + + run_readiness_check(builder.build().columns, stub_resource_provider) + + stub_resource_provider.model_registry.run_health_check.assert_called_once() + (called_aliases,), _ = stub_resource_provider.model_registry.run_health_check.call_args + assert set(called_aliases) == {"stub-text", "stub-image"} + + +def test_run_readiness_check_passes_skip_flagged_aliases_to_registry( + stub_resource_provider, + stub_model_configs, +) -> None: + """Readiness does not pre-filter ``skip_health_check=True`` aliases. + + The skip decision lives in ``ModelRegistry.run_health_check`` (covered by + ``test_model_registry``). Readiness's contract is "pass every referenced + alias through and let the registry decide" — verified here so future edits + don't accidentally start filtering at this layer. 
+ """ + stub_resource_provider.model_registry.run_health_check = Mock() + stub_resource_provider.mcp_registry = None + + columns = _build_columns( + model_configs=stub_model_configs, + llm_columns=[("col", "stub-text")], + ) + + run_readiness_check(columns, stub_resource_provider) + + stub_resource_provider.model_registry.run_health_check.assert_called_once_with(["stub-text"]) + + +# --------------------------------------------------------------------------- +# Async dispatch +# --------------------------------------------------------------------------- + + +def test_run_readiness_check_dispatches_to_async_registry_under_async_engine( + stub_resource_provider, + stub_model_configs, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When the async engine is selected, model probes route through ``arun_health_check``. + + The autouse fixture pins sync; this test overrides for the async path so the + branch in ``readiness._run_model_health_check`` gets coverage. + """ + monkeypatch.setattr(flags, "DATA_DESIGNER_ASYNC_ENGINE", True) + stub_resource_provider.model_registry.arun_health_check = Mock() + stub_resource_provider.mcp_registry = None + + columns = _build_columns( + model_configs=stub_model_configs, + llm_columns=[("col", "stub-text")], + ) + + # ``run_coroutine_threadsafe`` returns a Future; we want the readiness wrapper + # to call ``.result(timeout=...)`` on it, so install a Mock future whose + # ``.result`` returns ``None`` (success). + sentinel_future = Mock() + sentinel_future.result.return_value = None + + fake_loop = Mock() + + with ( + patch("data_designer.engine.readiness.ensure_async_engine_loop", return_value=fake_loop, create=True), + patch("asyncio.run_coroutine_threadsafe", return_value=sentinel_future) as mock_submit, + ): + run_readiness_check(columns, stub_resource_provider) + + # The async coroutine was created from arun_health_check and submitted to the loop. 
+ stub_resource_provider.model_registry.arun_health_check.assert_called_once_with(["stub-text"]) + mock_submit.assert_called_once() + sentinel_future.result.assert_called_once_with(timeout=180) + + +def test_run_readiness_check_cancels_future_and_reraises_on_timeout( + stub_resource_provider, + stub_model_configs, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A 180-second timeout cancels the future and re-raises ``TimeoutError``.""" + monkeypatch.setattr(flags, "DATA_DESIGNER_ASYNC_ENGINE", True) + stub_resource_provider.model_registry.arun_health_check = Mock() + stub_resource_provider.mcp_registry = None + + columns = _build_columns( + model_configs=stub_model_configs, + llm_columns=[("col", "stub-text")], + ) + + sentinel_future = Mock() + sentinel_future.result.side_effect = TimeoutError() + + with ( + patch("data_designer.engine.readiness.ensure_async_engine_loop", return_value=Mock(), create=True), + patch("asyncio.run_coroutine_threadsafe", return_value=sentinel_future), + pytest.raises(TimeoutError), + ): + run_readiness_check(columns, stub_resource_provider) + + sentinel_future.cancel.assert_called_once() diff --git a/packages/data-designer/src/data_designer/cli/commands/check_models.py b/packages/data-designer/src/data_designer/cli/commands/check_models.py new file mode 100644 index 000000000..0ae40902d --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/commands/check_models.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import typer + +from data_designer.cli.controllers.generation_controller import GenerationController + + +def check_models_command( + config_source: str = typer.Argument( + help=( + "Path or URL to a config file (.yaml/.yml/.json), or a local Python module (.py)" + " that defines a load_config_builder() function." 
+ ), + ), +) -> None: + """Check that every model and MCP tool referenced by the configuration is reachable. + + Runs the same readiness probes performed at the start of ``preview`` and + ``create``: a tiny generation against each referenced model alias, plus a + connectivity probe to each referenced MCP tool. Models with + ``skip_health_check=True`` are skipped. + + Complements ``validate``: ``validate`` checks the configuration is + well-formed (internal readiness); ``check-models`` checks the providers + it depends on are responsive (external readiness). + + Examples: + # Check models referenced by a YAML config + data-designer check-models my_config.yaml + + # Check models referenced by a remote config URL + data-designer check-models https://example.com/my_config.yaml + + # Check models referenced by a Python module + data-designer check-models my_config.py + """ + controller = GenerationController() + controller.run_check_models(config_source=config_source) diff --git a/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py b/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py index 4a4231c41..cc492dd1b 100644 --- a/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py +++ b/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py @@ -17,6 +17,7 @@ from data_designer.config.errors import InvalidConfigError from data_designer.config.utils.constants import DEFAULT_DISPLAY_WIDTH from data_designer.engine.storage.artifact_storage import ResumeMode +from data_designer.errors import DataDesignerError from data_designer.interface import DataDesigner from data_designer.logging import LOG_INDENT @@ -111,6 +112,34 @@ def run_validate(self, config_source: str) -> None: print_success("Configuration is valid") + def run_check_models(self, config_source: str) -> None: + """Load config and probe every referenced model and MCP tool. 
+ + Complements ``run_validate``: validate covers internal readiness + (configuration well-formedness); this covers external readiness + (provider liveness). + + Args: + config_source: Path to a config file or Python module. + """ + config_builder = self._load_config(config_source) + + print_header("Data Designer Check Models") + console.print(f" Config: [bold]{config_source}[/bold]") + console.print() + + try: + data_designer = DataDesigner() + data_designer.check_models(config_builder) + except DataDesignerError as e: + print_error(f"Model health check failed ({type(e).__name__}): {e}") + raise typer.Exit(code=1) + except Exception as e: + print_error(f"Model health check failed: {e}") + raise typer.Exit(code=1) + + print_success("All models and tools responded successfully") + def run_create( self, config_source: str, diff --git a/packages/data-designer/src/data_designer/cli/main.py b/packages/data-designer/src/data_designer/cli/main.py index 82fbe430b..3563c6913 100644 --- a/packages/data-designer/src/data_designer/cli/main.py +++ b/packages/data-designer/src/data_designer/cli/main.py @@ -76,6 +76,12 @@ def _is_version_request(args: list[str]) -> bool: "help": "Validate a Data Designer configuration", "rich_help_panel": "Generation", }, + "check-models": { + "module": f"{_CMD}.check_models", + "attr": "check_models_command", + "help": "Check that every referenced model and MCP tool is reachable", + "rich_help_panel": "Generation", + }, } ), add_completion=False, diff --git a/packages/data-designer/src/data_designer/interface/data_designer.py b/packages/data-designer/src/data_designer/interface/data_designer.py index 535e2391c..fd39ec40c 100644 --- a/packages/data-designer/src/data_designer/interface/data_designer.py +++ b/packages/data-designer/src/data_designer/interface/data_designer.py @@ -37,12 +37,14 @@ MODEL_PROVIDERS_FILE_PATH, ) from data_designer.config.utils.info import InfoType, InterfaceInfo +from data_designer.engine import flags from 
data_designer.engine.analysis.dataset_profiler import DataDesignerDatasetProfiler, DatasetProfilerConfig from data_designer.engine.compiler import compile_data_designer_config -from data_designer.engine.dataset_builders.dataset_builder import DATA_DESIGNER_ASYNC_ENGINE, DatasetBuilder +from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder from data_designer.engine.mcp.io import list_tool_names from data_designer.engine.model_provider import ModelProviderRegistry, resolve_model_provider_registry from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode +from data_designer.engine.readiness import run_readiness_check from data_designer.engine.resources.person_reader import ( PersonReader, create_person_reader, @@ -547,6 +549,37 @@ def validate(self, config_builder: DataDesignerConfigBuilder) -> None: resource_provider = self._create_resource_provider("validate-configuration", config_builder) compile_data_designer_config(config_builder.build(), resource_provider) + def check_models(self, config_builder: DataDesignerConfigBuilder) -> None: + """Probe every model and MCP tool referenced by the configuration. + + Runs the same readiness checks performed at the start of ``preview`` and + ``create``: a tiny generation against each referenced model alias, plus a + connectivity probe to each referenced MCP tool. Models whose ``ModelConfig`` + has ``skip_health_check=True`` are skipped. + + This complements :meth:`validate`: ``validate`` answers "is my configuration + well-formed?", ``check_models`` answers "are the providers it depends on + actually responsive?". Together they cover internal and external readiness + without needing to start a workload. + + Args: + config_builder: The DataDesignerConfigBuilder whose column configs + determine which model aliases and tool aliases are probed. + + Returns: + None if every (non-skipped) probe succeeded. 
+ + Raises: + ModelAuthenticationError, ModelNotFoundError, ModelAPIConnectionError, + and other typed errors from ``data_designer.engine.models.errors`` + for any failing model probe. + DatasetGenerationError: If a tool alias is referenced but no + ``MCPRegistry`` is configured. + TimeoutError: If async health-check execution exceeds 180 seconds. + """ + resource_provider = self._create_resource_provider("check-models", config_builder) + run_readiness_check(config_builder.build().columns, resource_provider) + def get_default_model_configs(self) -> list[ModelConfig]: """Get the default model configurations. @@ -722,7 +755,7 @@ def _resolve_client_concurrency_mode(config_builder: DataDesignerConfigBuilder) from inside the sync engine. Match the client mode to the actual engine choice so the fallback path is functional. """ - if not DATA_DESIGNER_ASYNC_ENGINE: + if not flags.DATA_DESIGNER_ASYNC_ENGINE: # Deliberate opt-out via env var. Surface the deprecation so users # know the sync path is going away. Mirror the ``allow_resize`` shape # in ``_resolve_async_compatibility``: emit both a ``logger.warning`` diff --git a/packages/data-designer/tests/cli/commands/test_check_models_command.py b/packages/data-designer/tests/cli/commands/test_check_models_command.py new file mode 100644 index 000000000..cdeb8a279 --- /dev/null +++ b/packages/data-designer/tests/cli/commands/test_check_models_command.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from data_designer.cli.commands.check_models import check_models_command + +# --------------------------------------------------------------------------- +# check_models_command delegation tests +# --------------------------------------------------------------------------- + + +@patch("data_designer.cli.commands.check_models.GenerationController") +def test_check_models_command_delegates_to_controller(mock_ctrl_cls: MagicMock) -> None: + """check_models_command delegates to GenerationController.run_check_models.""" + mock_ctrl = MagicMock() + mock_ctrl_cls.return_value = mock_ctrl + + check_models_command(config_source="config.yaml") + + mock_ctrl_cls.assert_called_once() + mock_ctrl.run_check_models.assert_called_once_with(config_source="config.yaml") + + +@patch("data_designer.cli.commands.check_models.GenerationController") +def test_check_models_command_passes_python_module_source(mock_ctrl_cls: MagicMock) -> None: + """check_models_command passes a .py source to the controller.""" + mock_ctrl = MagicMock() + mock_ctrl_cls.return_value = mock_ctrl + + check_models_command(config_source="my_config.py") + + mock_ctrl.run_check_models.assert_called_once_with(config_source="my_config.py") diff --git a/packages/data-designer/tests/cli/controllers/test_generation_controller.py b/packages/data-designer/tests/cli/controllers/test_generation_controller.py index b8047641a..a8c971b73 100644 --- a/packages/data-designer/tests/cli/controllers/test_generation_controller.py +++ b/packages/data-designer/tests/cli/controllers/test_generation_controller.py @@ -662,6 +662,85 @@ def test_run_validate_generic_exception(mock_load_config: MagicMock, mock_dd_cls assert exc_info.value.exit_code == 1 +# --------------------------------------------------------------------------- +# run_check_models tests +# 
--------------------------------------------------------------------------- + + +@patch(f"{_CTRL}.DataDesigner") +@patch(f"{_CTRL}.load_config_builder") +def test_run_check_models_success(mock_load_config: MagicMock, mock_dd_cls: MagicMock) -> None: + """Test successful check_models execution delegates to DataDesigner.check_models.""" + mock_builder = MagicMock(spec=DataDesignerConfigBuilder) + mock_load_config.return_value = mock_builder + + mock_dd = MagicMock() + mock_dd_cls.return_value = mock_dd + mock_dd.check_models.return_value = None + + controller = GenerationController() + controller.run_check_models(config_source="config.yaml") + + mock_load_config.assert_called_once_with("config.yaml") + mock_dd_cls.assert_called_once() + mock_dd.check_models.assert_called_once_with(mock_builder) + + +@patch(f"{_CTRL}.load_config_builder") +def test_run_check_models_config_load_error(mock_load_config: MagicMock) -> None: + """check_models exits with code 1 when config fails to load.""" + mock_load_config.side_effect = ConfigLoadError("File not found") + + controller = GenerationController() + with pytest.raises(typer.Exit) as exc_info: + controller.run_check_models(config_source="missing.yaml") + + assert exc_info.value.exit_code == 1 + + +@patch(f"{_CTRL}.DataDesigner") +@patch(f"{_CTRL}.load_config_builder") +def test_run_check_models_health_check_failure(mock_load_config: MagicMock, mock_dd_cls: MagicMock) -> None: + """check_models exits with code 1 when a probe fails with a generic exception.""" + mock_load_config.return_value = MagicMock(spec=DataDesignerConfigBuilder) + mock_dd = MagicMock() + mock_dd_cls.return_value = mock_dd + mock_dd.check_models.side_effect = RuntimeError("auth failed") + + controller = GenerationController() + with pytest.raises(typer.Exit) as exc_info: + controller.run_check_models(config_source="config.yaml") + + assert exc_info.value.exit_code == 1 + + +@patch(f"{_CTRL}.DataDesigner") +@patch(f"{_CTRL}.load_config_builder") +def 
test_run_check_models_typed_error_includes_class_name( + mock_load_config: MagicMock, mock_dd_cls: MagicMock, capsys: pytest.CaptureFixture[str] +) -> None: + """Typed engine errors exit 1 and surface the error class name to the user. + + Without this, an authentication failure and a connection failure look identical + on the terminal, defeating the purpose of typed engine errors. + """ + from data_designer.engine.models.errors import ModelAuthenticationError + + mock_load_config.return_value = MagicMock(spec=DataDesignerConfigBuilder) + mock_dd = MagicMock() + mock_dd_cls.return_value = mock_dd + mock_dd.check_models.side_effect = ModelAuthenticationError("bad creds") + + controller = GenerationController() + with pytest.raises(typer.Exit) as exc_info: + controller.run_check_models(config_source="config.yaml") + + assert exc_info.value.exit_code == 1 + captured = capsys.readouterr() + assert "ModelAuthenticationError" in captured.out + assert "bad creds" in captured.out + + # --------------------------------------------------------------------------- # run_create tests # --------------------------------------------------------------------------- diff --git a/packages/data-designer/tests/cli/test_main.py b/packages/data-designer/tests/cli/test_main.py index 928e85159..dd97c2269 100644 --- a/packages/data-designer/tests/cli/test_main.py +++ b/packages/data-designer/tests/cli/test_main.py @@ -180,6 +180,18 @@ def test_app_dispatches_lazy_create_command(mock_controller_cls: Mock) -> None: ) +@patch("data_designer.cli.commands.check_models.GenerationController") +def test_app_dispatches_lazy_check_models_command(mock_controller_cls: Mock) -> None: + """The Typer app dispatches the lazy-loaded check-models command.""" + mock_controller = Mock() + mock_controller_cls.return_value = mock_controller + + result = runner.invoke(app, ["check-models", "config.yaml"]) + + assert result.exit_code == 0 + 
mock_controller.run_check_models.assert_called_once_with(config_source="config.yaml") + + @patch("data_designer.cli.commands.plugin.PluginCatalogController") def test_app_dispatches_lazy_plugin_list_command(mock_controller_cls: Mock) -> None: """The plugin group lazily resolves command callbacks without loading a catalog.""" diff --git a/packages/data-designer/tests/interface/test_data_designer.py b/packages/data-designer/tests/interface/test_data_designer.py index 06fa5ce11..901a6ba34 100644 --- a/packages/data-designer/tests/interface/test_data_designer.py +++ b/packages/data-designer/tests/interface/test_data_designer.py @@ -17,11 +17,16 @@ import data_designer.interface.data_designer as dd_mod import data_designer.lazy_heavy_imports as lazy -from data_designer.config.column_configs import CustomColumnConfig, ExpressionColumnConfig, SamplerColumnConfig +from data_designer.config.column_configs import ( + CustomColumnConfig, + ExpressionColumnConfig, + LLMTextColumnConfig, + SamplerColumnConfig, +) from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.custom_column import custom_column_generator from data_designer.config.errors import InvalidConfigError -from data_designer.config.models import ModelProvider +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider from data_designer.config.processors import DropColumnsProcessorConfig from data_designer.config.run_config import JinjaRenderingEngine, RequestAdmissionTuningConfig, RunConfig from data_designer.config.sampler_params import CategorySamplerParams, DatetimeSamplerParams, SamplerType @@ -33,6 +38,7 @@ FileContentsSeedSource, HuggingFaceSeedSource, ) +from data_designer.engine import flags from data_designer.engine.resources.seed_reader import ( FileSystemSeedReader, SeedReaderError, @@ -444,7 +450,7 @@ def test_resolve_client_concurrency_mode_matches_engine_choice( (``allow_resize=True``) does not double-warn here; the 
builder layer emits its own warning when the run actually executes. """ - monkeypatch.setattr(dd_mod, "DATA_DESIGNER_ASYNC_ENGINE", env_value == "1") + monkeypatch.setattr(flags, "DATA_DESIGNER_ASYNC_ENGINE", env_value == "1") builder = _builder_with_allow_resize() if with_allow_resize else DataDesignerConfigBuilder() if not with_allow_resize: builder.add_column( @@ -1495,6 +1501,144 @@ def test_preview_with_dropped_columns( ) +@pytest.fixture +def stub_check_models_model_configs() -> list[ModelConfig]: + """Model configs whose ``provider`` field matches the local ``stub_model_providers`` fixture. + + The shared ``stub_model_configs`` fixture targets ``provider-1``; this file's + ``stub_model_providers`` defines ``stub-model-provider``. ``check_models`` builds + a real ``ResourceProvider``, so the two need to align. + """ + return [ + ModelConfig( + alias="stub-model", + model="stub-model", + provider="stub-model-provider", + inference_parameters=ChatCompletionInferenceParams( + temperature=0.9, + top_p=0.9, + max_tokens=2048, + ), + ) + ] + + +def test_check_models_invokes_readiness_check( + stub_artifact_path, + stub_model_providers, + stub_check_models_model_configs, + stub_managed_assets_path, +): + """check_models constructs a ResourceProvider and delegates to run_readiness_check.""" + config_builder = DataDesignerConfigBuilder(model_configs=stub_check_models_model_configs) + config_builder.add_column(LLMTextColumnConfig(name="text", prompt="x", model_alias="stub-model")) + + data_designer = DataDesigner( + artifact_path=stub_artifact_path, + model_providers=stub_model_providers, + secret_resolver=PlaintextResolver(), + managed_assets_path=stub_managed_assets_path, + ) + + with patch("data_designer.interface.data_designer.run_readiness_check") as mock_check: + data_designer.check_models(config_builder) + + assert mock_check.call_count == 1 + (called_columns, called_resource_provider), _ = mock_check.call_args + assert [c.name for c in called_columns] == 
["text"] + assert called_resource_provider is not None + + +def test_check_models_propagates_typed_model_error( + stub_artifact_path, + stub_model_providers, + stub_check_models_model_configs, + stub_managed_assets_path, +): + """Errors from the readiness probe surface unchanged so callers can branch on type.""" + from data_designer.engine.models.errors import ModelAuthenticationError + + config_builder = DataDesignerConfigBuilder(model_configs=stub_check_models_model_configs) + config_builder.add_column(LLMTextColumnConfig(name="text", prompt="x", model_alias="stub-model")) + + data_designer = DataDesigner( + artifact_path=stub_artifact_path, + model_providers=stub_model_providers, + secret_resolver=PlaintextResolver(), + managed_assets_path=stub_managed_assets_path, + ) + + with patch( + "data_designer.interface.data_designer.run_readiness_check", + side_effect=ModelAuthenticationError("bad creds"), + ): + with pytest.raises(ModelAuthenticationError, match="bad creds"): + data_designer.check_models(config_builder) + + +def test_check_models_passes_built_columns_to_readiness_check( + stub_artifact_path, + stub_model_providers, + stub_check_models_model_configs, + stub_managed_assets_path, +): + """The columns argument is the materialized config's column list, not the builder.""" + config_builder = DataDesignerConfigBuilder(model_configs=stub_check_models_model_configs) + config_builder.add_column( + SamplerColumnConfig( + name="city", + sampler_type=SamplerType.CATEGORY, + params=CategorySamplerParams(values=["nyc", "la"]), + ) + ) + config_builder.add_column(LLMTextColumnConfig(name="story", prompt="x", model_alias="stub-model")) + + data_designer = DataDesigner( + artifact_path=stub_artifact_path, + model_providers=stub_model_providers, + secret_resolver=PlaintextResolver(), + managed_assets_path=stub_managed_assets_path, + ) + + with patch("data_designer.interface.data_designer.run_readiness_check") as mock_check: + data_designer.check_models(config_builder) + 
+ (called_columns, _), _ = mock_check.call_args + assert [c.name for c in called_columns] == ["city", "story"] + + +def test_check_models_no_op_when_only_samplers( + stub_artifact_path, + stub_model_providers, + stub_check_models_model_configs, + stub_managed_assets_path, +): + """A config with no model-using columns still routes through readiness, which short-circuits.""" + config_builder = DataDesignerConfigBuilder(model_configs=stub_check_models_model_configs) + config_builder.add_column( + SamplerColumnConfig( + name="city", + sampler_type=SamplerType.CATEGORY, + params=CategorySamplerParams(values=["nyc", "la"]), + ) + ) + + data_designer = DataDesigner( + artifact_path=stub_artifact_path, + model_providers=stub_model_providers, + secret_resolver=PlaintextResolver(), + managed_assets_path=stub_managed_assets_path, + ) + + with patch("data_designer.interface.data_designer.run_readiness_check") as mock_check: + data_designer.check_models(config_builder) + + # check_models always invokes run_readiness_check; the readiness function itself + # short-circuits when no aliases are referenced. The contract here is "we always + # delegate"; the no-op decision lives in the engine. + assert mock_check.call_count == 1 + + def test_validate_raises_error_when_seed_collides( stub_artifact_path, stub_model_providers, diff --git a/plans/check-models/check-models.plan.md b/plans/check-models/check-models.plan.md new file mode 100644 index 000000000..e494b8fe2 --- /dev/null +++ b/plans/check-models/check-models.plan.md @@ -0,0 +1,288 @@ +--- +date: 2026-05-27 +status: draft +authors: + - mknepper +--- + +# Plan: Standalone Model & Tool Health Check on `DataDesigner` + +## Problem + +Today, the only way to verify that the models and MCP tools referenced by a +configuration are actually reachable is to start a workload. 
`DatasetBuilder.build()` +and `DatasetBuilder.build_preview()` both call +`_run_model_health_check_if_needed()` followed by `_run_mcp_tool_check_if_needed()` +as their first action +(`packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py:258`, +`:592`). To find out a credential is wrong, an alias is unregistered, or an MCP +server is down, a user has to invoke `preview()` or `create()` and wait for it +to fail at the same gate. + +`DataDesigner.validate(config_builder)` already covers the *internal* readiness +question — "is this configuration well-formed against my engine components?" — +but there is no symmetric *external* readiness method to ask "are the providers +this configuration depends on actually responsive?". + +## Goals + +1. Expose a public method on the `DataDesigner` interface that runs the same + external-readiness checks as workload startup, with no other side effects + (no artifact directory population, no batches, no profiling). +2. Cover both **models** and **MCP tools** — they are run together at workload + startup, and a user asking "am I good to go?" expects the same coverage. +3. Mirror `validate()` in shape and error policy: takes a config builder, + returns `None`, raises typed engine errors. Add a CLI command alongside the + existing `validate` command. +4. Avoid drift between this method and the workload startup path by extracting + the shared logic into the engine and calling it from both places. + +## Non-goals + +- Reworking the existing health-check probes themselves (still a tiny `"Hello!"` + generation per model; no changes to `ModelRegistry.run_health_check` / + `arun_health_check` semantics). +- Concurrency limits, partial-failure aggregation, or per-alias filtering. The + method is a one-shot pass-through that fails fast on the first error, exactly + like the startup path. These can be follow-ups if requested. +- Touching `skip_health_check` semantics. 
Models with `skip_health_check=True` + remain skipped, just as they are at startup. + +## Design + +### Naming + +Method: `DataDesigner.check_models(config_builder)`. CLI: `dd check-models`. + +"Models" in this codebase already names the externally-hosted resources a +configuration depends on. MCP tools are coupled to model-using columns — +`_run_mcp_tool_check_if_needed` collects aliases from +`llm_generated_column_configs` (`dataset_builder.py:1354`), so any config with +tools necessarily has at least one model-using column; you cannot have a +"tools-only" config that this method would mis-name. Treating MCP tool +liveness as part of "checking the models" is consistent with how the codebase +already groups them at the startup gate. + +This pairs cleanly with the existing surface: `validate` for internal +readiness, `check_models` for external readiness. + +### Public method on `DataDesigner` + +Add to `packages/data-designer/src/data_designer/interface/data_designer.py`, +adjacent to `validate` (line 533): + +```python +def check_models(self, config_builder: DataDesignerConfigBuilder) -> None: + """Probe every model and MCP tool referenced by the configuration. + + Runs the same readiness checks performed at the start of ``preview`` and + ``create``: a tiny generation against each referenced model alias, plus + a connectivity probe to each referenced MCP tool. Models whose + ``ModelConfig`` has ``skip_health_check=True`` are skipped. + + Args: + config_builder: The DataDesignerConfigBuilder whose column configs + determine which model aliases and tool aliases are probed. + + Returns: + None if every (non-skipped) probe succeeded. + + Raises: + Typed model errors from ``data_designer.engine.models.errors`` + (e.g. ``ModelAuthenticationError``, ``ModelNotFoundError``, + ``ModelAPIConnectionError``) for any failing model probe. + DatasetGenerationError: If a tool alias is referenced but no + ``MCPRegistry`` is configured. 
+ TimeoutError: If async health-check execution exceeds 180 seconds. + """ + resource_provider = self._create_resource_provider("check-models", config_builder) + config = config_builder.build() + run_readiness_check(config.columns, resource_provider) +``` + +### Shared helper module + +Add a new module +`packages/data-designer-engine/src/data_designer/engine/readiness.py`. It +hosts the shared logic so the standalone method and the workload startup +path can never drift: + +```python +# data_designer/engine/readiness.py + +def run_readiness_check( + column_configs: Sequence[ColumnConfig], + resource_provider: ResourceProvider, +) -> None: + """Run model + MCP tool health checks for the given column configs. + + Used by both ``DatasetBuilder.build``/``build_preview`` (at workload start) + and ``DataDesigner.check_models`` (standalone). + """ + _run_model_health_check(column_configs, resource_provider) + _run_mcp_tool_health_check(column_configs, resource_provider) + + +def _run_model_health_check( + column_configs: Sequence[ColumnConfig], + resource_provider: ResourceProvider, +) -> None: + ... # body lifted from DatasetBuilder._run_model_health_check_if_needed + + +def _run_mcp_tool_health_check( + column_configs: Sequence[ColumnConfig], + resource_provider: ResourceProvider, +) -> None: + ... # body lifted from DatasetBuilder._run_mcp_tool_check_if_needed +``` + +The bodies are essentially the existing +`_run_model_health_check_if_needed` (`dataset_builder.py:1330`) and +`_run_mcp_tool_check_if_needed` (`dataset_builder.py:1352`), rewritten to +take `column_configs` and `resource_provider` as arguments instead of reading +`self.single_column_configs` and `self._resource_provider`. + +`run_readiness_check` is the only public symbol the module exports; +`_run_model_health_check` and `_run_mcp_tool_health_check` stay +module-private. 
This is achievable because the two `DatasetBuilder` call +sites that currently invoke the model check and tool check back-to-back +(`dataset_builder.py:258-259` and `:592-593`) have no logic between them and +both run unconditionally — they collapse cleanly to a single +`run_readiness_check(...)` call: + +```python +# dataset_builder.py:258 and :592 (after the change) +run_readiness_check(self.single_column_configs, self._resource_provider) +``` + +The two `DatasetBuilder._run_*_if_needed` instance methods are removed +entirely (no longer needed as delegating wrappers). Builder tests that +currently patch `_run_mcp_tool_check_if_needed` to suppress MCP probing +during construction-focused tests +(`test_dataset_builder.py:2546, 2596, 2633, 2666`) migrate to patching +`run_readiness_check` at its import site in `dataset_builder.py`. In every +case those tests want the readiness gate as a whole bypassed, so widening +from "MCP-only" to "readiness as a whole" matches their actual intent. + +A new module is preferred over keeping the helpers in `dataset_builder.py`: +the readiness pass is conceptually a separate phase of execution, callable +from outside the dataset builder, and putting it in its own file keeps that +boundary explicit. + +### CLI command + +Add `packages/data-designer/src/data_designer/cli/commands/check_models.py`, +modelled on `validate.py`. Wire it into the CLI in +`packages/data-designer/src/data_designer/cli/main.py` next to the +`validate` registration: + +- `dd validate` — internal readiness (config well-formedness) +- `dd check-models` — external readiness (provider liveness) + +### Async dispatch contract + +`run_readiness_check` preserves the existing `DATA_DESIGNER_ASYNC_ENGINE` +switch and 180-second timeout for the model probes. MCP tool probes remain +sync because `MCPRegistry` has no async health-check method today +(`packages/data-designer-engine/src/data_designer/engine/mcp/registry.py:180`). 
+This matches current startup behavior exactly: the workload-startup pass +also runs MCP probes synchronously regardless of which engine is in use, so +nothing changes in observable behavior with this refactor. Adding async +parity for MCP is tracked as a follow-up (see Follow-ups). + +### Logging + +Both probes already produce informative logs from the registry layer (the +`🩺 Running health checks for models...` line at +`packages/data-designer-engine/src/data_designer/engine/models/registry.py:229`, +plus per-alias `✅ Passed!` / `❌ Failed!` lines). The new method needs only a +single high-level entry log, mirroring how `validate` currently emits one +log line via `compile_data_designer_config`. No additional verbosity. + +--- + +## Files to change + +### `packages/data-designer-engine/` + +| File | Change | +|---|---| +| `src/data_designer/engine/readiness.py` | New module. Public surface: `run_readiness_check`. Module-private: `_run_model_health_check`, `_run_mcp_tool_health_check`. | +| `src/data_designer/engine/dataset_builders/dataset_builder.py` | Remove `_run_model_health_check_if_needed` (`:1330`) and `_run_mcp_tool_check_if_needed` (`:1352`). Replace each pair of call sites at `:258-259` and `:592-593` with a single `run_readiness_check(self.single_column_configs, self._resource_provider)` call. | +| `tests/engine/dataset_builders/test_dataset_builder.py` | The four `patch.object(builder, "_run_mcp_tool_check_if_needed")` sites (`:2546, :2596, :2633, :2666`) migrate to patching `run_readiness_check` at its import site in `dataset_builder`. Existing model-check tests at `:327, :362` migrate similarly. Coverage is preserved (and arguably improved — those tests want the readiness gate bypassed, which is now expressible in one patch instead of two). 
| +| `tests/engine/test_readiness.py` (new) | Direct tests for `run_readiness_check`: success, model auth failure surfaces verbatim, MCP missing-registry surfaces `DatasetGenerationError`, no-aliases short-circuit, async-engine timeout, `skip_health_check=True` honored. Component functions are exercised through `run_readiness_check` rather than directly, since they're module-private. | + +### `packages/data-designer/` + +| File | Change | +|---|---| +| `src/data_designer/interface/data_designer.py` | Add `check_models` method next to `validate`. Import `run_readiness_check` from `data_designer.engine.readiness`. | +| `src/data_designer/cli/commands/check_models.py` | New file. Mirrors `validate.py`. | +| `src/data_designer/cli/main.py` | Register the new command in the lazy Typer group. | +| `tests/interface/test_data_designer.py` | New tests for `check_models`: success path; surfaces `ModelAuthenticationError`; surfaces `DatasetGenerationError` for missing MCP registry; no-op when no models or tools are referenced; respects `skip_health_check=True`. | +| `tests/cli/commands/test_check_models_command.py` | New file mirroring `test_validate_command.py`. | + +### Docs (handled in a follow-up via the `datadesigner-docs` skill) + +- `docs/concepts/models/model-configs.md` — cross-reference the new method + next to the existing `skip_health_check` discussion (`:112`). +- Data Designer API reference — add `check_models` next to `validate`. +- A short Fern page or section explaining the `validate` / `check_models` + pair as the "pre-flight" surface. + +Documentation is intentionally scoped out of this plan's first PR to keep the +change set focused; the implementation PR will note the docs follow-up. + +--- + +## Validation + +Engine-level tests already cover the registry probes themselves +(`packages/data-designer-engine/tests/engine/models/test_model_registry.py:329-449`). +The work here adds: + +1. 
**Refactor safety** — existing `DatasetBuilder` health-check tests + (`test_dataset_builder.py:327, 362, 2546, 2596, 2633, 2666`) keep passing + once their patches are migrated to `run_readiness_check`. +2. **New free-function tests** — direct coverage of `run_readiness_check`, + which also exercises the two module-private component functions, against fakes for both registries. +3. **Interface tests** — `DataDesigner.check_models` end-to-end with a + mocked `ResourceProvider`/registries: + - Success path returns `None`. + - Model auth failure surfaces the typed error verbatim. + - Missing MCP registry surfaces `DatasetGenerationError`. + - Configuration with no model-using columns and no tools is a no-op. + - `skip_health_check=True` aliases are not probed. +4. **CLI test** — invocation delegates to the controller, mirroring + `test_validate_command.py`. + +Manual verification before merge: + +```bash +make check-all-fix +make test +``` + +Plus a smoke run with a real (or recorded) provider: + +```bash +dd check-models path/to/config.yaml # exits 0 on success +dd check-models path/to/bad-creds.yaml # exits non-zero with typed error +``` + +--- + +## Follow-ups + +- **Async parity for MCP tool probes.** Today, MCP probes go through + `MCPRegistry.run_health_check` regardless of which engine is in use; there + is no `arun_health_check`. This is a stylistic asymmetry rather than a + correctness gap (the probe runs serially before any concurrent work + begins), but adding async parity would let the readiness pass complete + fully on the async event loop. Worth a tracking issue if/when MCP usage + grows. +- **Per-alias filtering / partial reporting.** The current method is a + one-shot fail-fast pass-through. If users want "tell me everything that's + broken in one go" or "only check this subset of models," that's an + additive API on top of the same engine helper.