diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index fc76fe43ed..1077eff8a9 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -354,7 +354,7 @@ def interpolate_env(self, conf: RunConfigurationT): password=interpolator.interpolate_or_error(conf.registry_auth.password), ) if isinstance(conf, ServiceConfiguration): - for probe in conf.probes: + for probe in conf.probes or []: for header in probe.headers: header.value = interpolator.interpolate_or_error(header.value) if probe.url: diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 3b2c7812b9..f6c1b385a2 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -56,6 +56,8 @@ DEFAULT_PROBE_METHOD = "get" MAX_PROBE_URL_LEN = 2048 DEFAULT_REPLICA_GROUP_NAME = "0" +DEFAULT_MODEL_PROBE_TIMEOUT = 30 +DEFAULT_MODEL_PROBE_URL = "/v1/chat/completions" class RunConfigurationType(str, Enum): @@ -851,9 +853,9 @@ class ServiceConfigurationParams(CoreModel): ] = None rate_limits: Annotated[list[RateLimit], Field(description="Rate limiting rules")] = [] probes: Annotated[ - list[ProbeConfig], + Optional[list[ProbeConfig]], Field(description="List of probes used to determine job health"), - ] = [] + ] = None # None = omitted (may get default when model is set); [] = explicit empty replicas: Annotated[ Optional[Union[List[ReplicaGroup], Range[int]]], @@ -895,7 +897,9 @@ def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]: return v @validator("probes") - def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]: + def validate_probes(cls, v: Optional[list[ProbeConfig]]) -> Optional[list[ProbeConfig]]: + if v is None: + return v if has_duplicates(v): # Using a custom validator instead of Field(unique_items=True) to avoid Pydantic bug: # https://github.com/pydantic/pydantic/issues/3765 @@ -932,6 +936,35 @@ def validate_replicas( ) return v + @root_validator() + def set_default_probes_for_model(cls, values): + model = values.get("model") + probes = values.get("probes") + if model is not None and probes is None: + body = orjson.dumps( + { + "model": model.name, + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 1, + } + ).decode("utf-8") + values["probes"] = [ + ProbeConfig( + type="http", + method="post", + url=DEFAULT_MODEL_PROBE_URL, + headers=[ + HTTPHeaderSpec(name="Content-Type", value="application/json"), + ], + body=body, + timeout=DEFAULT_MODEL_PROBE_TIMEOUT, + ) + ] + elif probes is None: + # Probes omitted and model not set: normalize to empty list for downstream. + values["probes"] = [] + return values + @root_validator() def validate_scaling(cls, values): scaling = values.get("scaling") diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index df6738a774..44579bf81c 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -394,7 +394,7 @@ def _service_port(self) -> Optional[int]: def _probes(self) -> list[ProbeSpec]: if isinstance(self.run_spec.configuration, ServiceConfiguration): - return list(map(_probe_config_to_spec, self.run_spec.configuration.probes)) + return list(map(_probe_config_to_spec, self.run_spec.configuration.probes or [])) return [] diff --git a/src/dstack/_internal/server/services/runs/spec.py b/src/dstack/_internal/server/services/runs/spec.py index db81eb724a..ad2fcef1ff 100644 --- a/src/dstack/_internal/server/services/runs/spec.py +++ b/src/dstack/_internal/server/services/runs/spec.py @@ -94,13 +94,13 @@ def validate_run_spec_and_set_defaults( raise ServerClientError( "Scheduled services with autoscaling to zero are not supported" ) - if len(run_spec.configuration.probes) > settings.MAX_PROBES_PER_JOB: + if len(run_spec.configuration.probes or []) > settings.MAX_PROBES_PER_JOB: raise ServerClientError( f"Cannot configure more than {settings.MAX_PROBES_PER_JOB} probes" ) if any( p.timeout is not None and p.timeout > settings.MAX_PROBE_TIMEOUT - for p in run_spec.configuration.probes + for p in (run_spec.configuration.probes or []) ): raise ServerClientError( f"Probe timeout cannot be longer than {settings.MAX_PROBE_TIMEOUT}s" diff --git a/src/tests/_internal/core/models/test_configurations.py b/src/tests/_internal/core/models/test_configurations.py index 65eec62642..1ff025ea2f 100644 --- a/src/tests/_internal/core/models/test_configurations.py +++ b/src/tests/_internal/core/models/test_configurations.py @@ -5,6 +5,8 @@ from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.common import RegistryAuth from dstack._internal.core.models.configurations import ( + DEFAULT_MODEL_PROBE_TIMEOUT, + DEFAULT_MODEL_PROBE_URL, DevEnvironmentConfigurationParams, RepoSpec, parse_run_configuration, @@ -13,6 +15,49 @@ class TestParseConfiguration: + def test_service_model_sets_default_probes_when_probes_omitted(self): + conf = { + "type": "service", + "commands": ["python3 -m http.server"], + "port": 8000, + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + } + parsed = parse_run_configuration(conf) + assert len(parsed.probes) == 1 + probe = parsed.probes[0] + assert probe.type == "http" + assert probe.method == "post" + assert probe.url == DEFAULT_MODEL_PROBE_URL + assert probe.timeout == DEFAULT_MODEL_PROBE_TIMEOUT + assert len(probe.headers) == 1 + assert probe.headers[0].name == "Content-Type" + assert probe.headers[0].value == "application/json" + assert "meta-llama/Meta-Llama-3.1-8B-Instruct" in (probe.body or "") + assert "max_tokens" in (probe.body or "") + + def test_service_model_does_not_override_explicit_probes(self): + conf = { + "type": "service", + "commands": ["python3 -m http.server"], + "port": 8000, + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "probes": [{"type": "http", "url": "/health"}], + } + parsed = parse_run_configuration(conf) + assert len(parsed.probes) == 1 + assert parsed.probes[0].url == "/health" + + def test_service_model_explicit_empty_probes_no_default(self): + conf = { + "type": "service", + "commands": ["python3 -m http.server"], + "port": 8000, + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "probes": [], + } + parsed = parse_run_configuration(conf) + assert len(parsed.probes) == 0 + def test_services_replicas_and_scaling(self): def test_conf(replicas: Any, scaling: Optional[Any] = None): conf = {