Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def interpolate_env(self, conf: RunConfigurationT):
password=interpolator.interpolate_or_error(conf.registry_auth.password),
)
if isinstance(conf, ServiceConfiguration):
for probe in conf.probes:
for probe in conf.probes or []:
for header in probe.headers:
header.value = interpolator.interpolate_or_error(header.value)
if probe.url:
Expand Down
39 changes: 36 additions & 3 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
DEFAULT_PROBE_METHOD = "get"
MAX_PROBE_URL_LEN = 2048
DEFAULT_REPLICA_GROUP_NAME = "0"
DEFAULT_MODEL_PROBE_TIMEOUT = 30
DEFAULT_MODEL_PROBE_URL = "/v1/chat/completions"


class RunConfigurationType(str, Enum):
Expand Down Expand Up @@ -851,9 +853,9 @@ class ServiceConfigurationParams(CoreModel):
] = None
rate_limits: Annotated[list[RateLimit], Field(description="Rate limiting rules")] = []
probes: Annotated[
list[ProbeConfig],
Optional[list[ProbeConfig]],
Field(description="List of probes used to determine job health"),
] = []
] = None # None = omitted (may get default when model is set); [] = explicit empty

replicas: Annotated[
Optional[Union[List[ReplicaGroup], Range[int]]],
Expand Down Expand Up @@ -895,7 +897,9 @@ def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
return v

@validator("probes")
def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
def validate_probes(cls, v: Optional[list[ProbeConfig]]) -> Optional[list[ProbeConfig]]:
if v is None:
return v
if has_duplicates(v):
# Using a custom validator instead of Field(unique_items=True) to avoid Pydantic bug:
# https://github.com/pydantic/pydantic/issues/3765
Expand Down Expand Up @@ -932,6 +936,35 @@ def validate_replicas(
)
return v

@root_validator()
def set_default_probes_for_model(cls, values):
model = values.get("model")
probes = values.get("probes")
if model is not None and probes is None:
body = orjson.dumps(
{
"model": model.name,
"messages": [{"role": "user", "content": "hi"}],
"max_tokens": 1,
}
).decode("utf-8")
values["probes"] = [
ProbeConfig(
type="http",
method="post",
url=DEFAULT_MODEL_PROBE_URL,
headers=[
HTTPHeaderSpec(name="Content-Type", value="application/json"),
],
body=body,
timeout=DEFAULT_MODEL_PROBE_TIMEOUT,
)
]
elif probes is None:
# Probes omitted and model not set: normalize to empty list for downstream.
values["probes"] = []
return values

@root_validator()
def validate_scaling(cls, values):
scaling = values.get("scaling")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def _service_port(self) -> Optional[int]:

def _probes(self) -> list[ProbeSpec]:
if isinstance(self.run_spec.configuration, ServiceConfiguration):
return list(map(_probe_config_to_spec, self.run_spec.configuration.probes))
return list(map(_probe_config_to_spec, self.run_spec.configuration.probes or []))
return []


Expand Down
4 changes: 2 additions & 2 deletions src/dstack/_internal/server/services/runs/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ def validate_run_spec_and_set_defaults(
raise ServerClientError(
"Scheduled services with autoscaling to zero are not supported"
)
if len(run_spec.configuration.probes) > settings.MAX_PROBES_PER_JOB:
if len(run_spec.configuration.probes or []) > settings.MAX_PROBES_PER_JOB:
raise ServerClientError(
f"Cannot configure more than {settings.MAX_PROBES_PER_JOB} probes"
)
if any(
p.timeout is not None and p.timeout > settings.MAX_PROBE_TIMEOUT
for p in run_spec.configuration.probes
for p in (run_spec.configuration.probes or [])
):
raise ServerClientError(
f"Probe timeout cannot be longer than {settings.MAX_PROBE_TIMEOUT}s"
Expand Down
45 changes: 45 additions & 0 deletions src/tests/_internal/core/models/test_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.common import RegistryAuth
from dstack._internal.core.models.configurations import (
DEFAULT_MODEL_PROBE_TIMEOUT,
DEFAULT_MODEL_PROBE_URL,
DevEnvironmentConfigurationParams,
RepoSpec,
parse_run_configuration,
Expand All @@ -13,6 +15,49 @@


class TestParseConfiguration:
def test_service_model_sets_default_probes_when_probes_omitted(self):
conf = {
"type": "service",
"commands": ["python3 -m http.server"],
"port": 8000,
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
}
parsed = parse_run_configuration(conf)
assert len(parsed.probes) == 1
probe = parsed.probes[0]
assert probe.type == "http"
assert probe.method == "post"
assert probe.url == DEFAULT_MODEL_PROBE_URL
assert probe.timeout == DEFAULT_MODEL_PROBE_TIMEOUT
assert len(probe.headers) == 1
assert probe.headers[0].name == "Content-Type"
assert probe.headers[0].value == "application/json"
assert "meta-llama/Meta-Llama-3.1-8B-Instruct" in (probe.body or "")
assert "max_tokens" in (probe.body or "")

def test_service_model_does_not_override_explicit_probes(self):
conf = {
"type": "service",
"commands": ["python3 -m http.server"],
"port": 8000,
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"probes": [{"type": "http", "url": "/health"}],
}
parsed = parse_run_configuration(conf)
assert len(parsed.probes) == 1
assert parsed.probes[0].url == "/health"

def test_service_model_explicit_empty_probes_no_default(self):
conf = {
"type": "service",
"commands": ["python3 -m http.server"],
"port": 8000,
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"probes": [],
}
parsed = parse_run_configuration(conf)
assert len(parsed.probes) == 0

def test_services_replicas_and_scaling(self):
def test_conf(replicas: Any, scaling: Optional[Any] = None):
conf = {
Expand Down