Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/dstack/_internal/server/services/gateways/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) ->

async def get_or_add_gateway_connection(
session: AsyncSession, gateway_id: uuid.UUID
) -> GatewayConnection:
) -> tuple[GatewayModel, GatewayConnection]:
gateway = await session.get(GatewayModel, gateway_id)
if gateway is None:
raise GatewayError("Gateway not found")
Expand All @@ -460,7 +460,7 @@ async def get_or_add_gateway_connection(
"Failed to connect to gateway %s: %s", gateway.gateway_compute.ip_address, e
)
raise GatewayError("Failed to connect to gateway")
return conn
return gateway, conn


async def init_gateways(session: AsyncSession):
Expand Down
63 changes: 47 additions & 16 deletions src/dstack/_internal/server/services/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec
from dstack._internal.server import settings
from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel
from dstack._internal.server.services import events
from dstack._internal.server.services.gateways import (
get_gateway_configuration,
get_or_add_gateway_connection,
Expand Down Expand Up @@ -114,7 +115,7 @@ async def _register_service_in_gateway(
domain = service_spec.get_domain()
assert domain is not None

conn = await get_or_add_gateway_connection(session, gateway.id)
_, conn = await get_or_add_gateway_connection(session, gateway.id)
try:
logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url)
async with conn.client() as client:
Expand All @@ -131,13 +132,21 @@ async def _register_service_in_gateway(
ssh_private_key=run_model.project.ssh_private_key,
router=router,
)
logger.info("%s: service is registered as %s", fmt(run_model), service_spec.url)
except SSHError:
raise ServerClientError("Gateway tunnel is not working")
except httpx.RequestError as e:
logger.debug("Gateway request failed", exc_info=True)
raise GatewayError(f"Gateway is not working: {e!r}")

events.emit(
session,
"Service registered in gateway",
actor=events.SystemActor(),
targets=[
events.Target.from_model(run_model),
events.Target.from_model(gateway),
],
)
return service_spec


Expand Down Expand Up @@ -193,8 +202,9 @@ async def register_replica(
ssh_head_proxy: Optional[SSHConnectionParams],
ssh_head_proxy_private_key: Optional[str],
):
gateway = None
if gateway_id is not None:
conn = await get_or_add_gateway_connection(session, gateway_id)
gateway, conn = await get_or_add_gateway_connection(session, gateway_id)
job_submission = jobs_services.job_model_to_job_submission(job_model)
try:
logger.debug("%s: registering replica for service %s", fmt(job_model), run.id.hex)
Expand Down Expand Up @@ -225,17 +235,21 @@ async def register_replica(
else:
raise
job_model.registered = True
logger.info(
"%s: service replica registered to receive requests, gateway=%s",
fmt(job_model),
gateway_id is not None,
targets = [events.Target.from_model(job_model)]
if gateway is not None:
targets.append(events.Target.from_model(gateway))
events.emit(
session,
"Service replica registered to receive requests",
actor=events.SystemActor(),
targets=targets,
)


async def unregister_service(session: AsyncSession, run_model: RunModel):
if run_model.gateway_id is None: # in-server proxy
return
conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
res = await session.execute(
select(ProjectModel).where(ProjectModel.id == run_model.project_id)
)
Expand All @@ -247,24 +261,37 @@ async def unregister_service(session: AsyncSession, run_model: RunModel):
project=project.name,
run_name=run_model.run_name,
)
logger.debug("%s: service is unregistered", fmt(run_model))
event_msg = "Service unregistered from gateway"
except GatewayError as e:
# ignore if service is not registered
logger.warning("%s: unregistering service: %s", fmt(run_model), e)
event_msg = f"Gateway error when unregistering service: {e}"
except (httpx.RequestError, SSHError) as e:
logger.debug("Gateway request failed", exc_info=True)
raise GatewayError(repr(e))
events.emit(
session,
event_msg,
actor=events.SystemActor(),
targets=[
events.Target.from_model(run_model),
events.Target.from_model(gateway),
],
)


async def unregister_replica(session: AsyncSession, job_model: JobModel):
if not job_model.registered: # non-services and unregistered service replicas
return
res = await session.execute(
select(RunModel)
.where(RunModel.id == job_model.run_id)
.options(joinedload(RunModel.project).joinedload(ProjectModel.backends))
.options(joinedload(RunModel.project))
)
run_model = res.unique().scalar_one()
gateway = None
if run_model.gateway_id is not None:
conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
try:
logger.debug(
"%s: unregistering replica from service %s", fmt(job_model), job_model.run_id.hex
Expand All @@ -282,10 +309,14 @@ async def unregister_replica(session: AsyncSession, job_model: JobModel):
logger.debug("Gateway request failed", exc_info=True)
raise GatewayError(repr(e))
job_model.registered = False
logger.info(
"%s: service replica unregistered from receiving requests, gateway=%s",
fmt(job_model),
run_model.gateway_id is not None,
targets = [events.Target.from_model(job_model)]
if gateway is not None:
targets.append(events.Target.from_model(gateway))
events.emit(
session,
"Service replica unregistered from receiving requests",
actor=events.SystemActor(),
targets=targets,
)


Expand Down Expand Up @@ -314,7 +345,7 @@ async def update_service_desired_replica_count(
) -> None:
stats = None
if run_model.gateway_id is not None:
conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
_, conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
stats = await conn.get_stats(run_model.project.name, run_model.run_name)
replica_groups = configuration.replica_groups
desired_replica_counts = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,11 @@ async def test_registers_service_replica_immediately_if_no_probes(
await session.refresh(job)
assert job.status == JobStatus.RUNNING
assert job.registered
events = await list_events(session)
assert {e.message for e in events} == {
"Job status changed PULLING -> RUNNING",
"Service replica registered to receive requests",
}

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
Expand Down Expand Up @@ -1104,7 +1109,14 @@ async def test_registers_service_replica_only_after_probes_pass(
await process_running_jobs()

await session.refresh(job)
assert job.registered == expect_to_register
events = await list_events(session)
if expect_to_register:
assert job.registered
assert len(events) == 1
assert events[0].message == "Service replica registered to receive requests"
else:
assert not job.registered
assert not events


class TestPatchBaseImageForAwsEfa:
Expand Down
12 changes: 11 additions & 1 deletion src/tests/_internal/server/routers/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
get_fleet_spec,
get_job_provisioning_data,
get_run_spec,
list_events,
)
from dstack._internal.server.testing.matchers import SomeUUID4Str
from tests._internal.server.background.tasks.test_process_running_jobs import settings
Expand Down Expand Up @@ -2085,41 +2086,47 @@ def mock_gateway_connections(self) -> Generator[None, None, None]:
"specified_gateway_in_run_conf",
"expected_service_url",
"expected_model_url",
"is_gateway",
),
[
pytest.param(
[("default-gateway", True), ("non-default-gateway", False)],
None,
"https://test-service.default-gateway.example",
"https://gateway.default-gateway.example",
True,
id="submits-to-default-gateway",
),
pytest.param(
[("default-gateway", True), ("non-default-gateway", False)],
True,
"https://test-service.default-gateway.example",
"https://gateway.default-gateway.example",
True,
id="submits-to-default-gateway-when-gateway-true",
),
pytest.param(
[("default-gateway", True), ("non-default-gateway", False)],
"non-default-gateway",
"https://test-service.non-default-gateway.example",
"https://gateway.non-default-gateway.example",
True,
id="submits-to-specified-gateway",
),
pytest.param(
[("non-default-gateway", False)],
None,
"/proxy/services/test-project/test-service/",
"/proxy/models/test-project/",
False,
id="submits-in-server-when-no-default-gateway",
),
pytest.param(
[("default-gateway", True)],
False,
"/proxy/services/test-project/test-service/",
"/proxy/models/test-project/",
False,
id="submits-in-server-when-specified",
),
],
Expand All @@ -2130,9 +2137,10 @@ async def test_submit_to_correct_proxy(
session: AsyncSession,
client: AsyncClient,
existing_gateways: List[Tuple[str, bool]],
specified_gateway_in_run_conf: str,
specified_gateway_in_run_conf: Union[str, bool, None],
expected_service_url: str,
expected_model_url: str,
is_gateway: bool,
) -> None:
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user, name="test-project")
Expand Down Expand Up @@ -2171,6 +2179,8 @@ async def test_submit_to_correct_proxy(
assert response.status_code == 200
assert response.json()["service"]["url"] == expected_service_url
assert response.json()["service"]["model"]["base_url"] == expected_model_url
events = await list_events(session)
assert ("Service registered in gateway" in {e.message for e in events}) == is_gateway

@pytest.mark.asyncio
async def test_return_error_if_specified_gateway_not_exists(
Expand Down