From acd59a0228289ef41425c2177beb7f776ab6e3ba Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Thu, 29 Jan 2026 19:40:25 +0100 Subject: [PATCH] Add service and replica registration events - Service registered in gateway - Service unregistered from gateway - Service replica registered to receive requests - Service replica unregistered from receiving requests --- .../server/services/gateways/__init__.py | 4 +- .../server/services/services/__init__.py | 63 ++++++++++++++----- .../tasks/test_process_running_jobs.py | 14 ++++- .../_internal/server/routers/test_runs.py | 12 +++- 4 files changed, 73 insertions(+), 20 deletions(-) diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index bff20466a..ab89c2a7c 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -444,7 +444,7 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) -> async def get_or_add_gateway_connection( session: AsyncSession, gateway_id: uuid.UUID -) -> GatewayConnection: +) -> tuple[GatewayModel, GatewayConnection]: gateway = await session.get(GatewayModel, gateway_id) if gateway is None: raise GatewayError("Gateway not found") @@ -460,7 +460,7 @@ async def get_or_add_gateway_connection( "Failed to connect to gateway %s: %s", gateway.gateway_compute.ip_address, e ) raise GatewayError("Failed to connect to gateway") - return conn + return gateway, conn async def init_gateways(session: AsyncSession): diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 45bb1fe0f..06aa5b0ef 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -29,6 +29,7 @@ from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel +from dstack._internal.server.services import events from dstack._internal.server.services.gateways import ( get_gateway_configuration, get_or_add_gateway_connection, @@ -114,7 +115,7 @@ async def _register_service_in_gateway( domain = service_spec.get_domain() assert domain is not None - conn = await get_or_add_gateway_connection(session, gateway.id) + _, conn = await get_or_add_gateway_connection(session, gateway.id) try: logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url) async with conn.client() as client: @@ -131,13 +132,21 @@ async def _register_service_in_gateway( ssh_private_key=run_model.project.ssh_private_key, router=router, ) - logger.info("%s: service is registered as %s", fmt(run_model), service_spec.url) except SSHError: raise ServerClientError("Gateway tunnel is not working") except httpx.RequestError as e: logger.debug("Gateway request failed", exc_info=True) raise GatewayError(f"Gateway is not working: {e!r}") + events.emit( + session, + "Service registered in gateway", + actor=events.SystemActor(), + targets=[ + events.Target.from_model(run_model), + events.Target.from_model(gateway), + ], + ) return service_spec @@ -193,8 +202,9 @@ async def register_replica( ssh_head_proxy: Optional[SSHConnectionParams], ssh_head_proxy_private_key: Optional[str], ): + gateway = None if gateway_id is not None: - conn = await get_or_add_gateway_connection(session, gateway_id) + gateway, conn = await get_or_add_gateway_connection(session, gateway_id) job_submission = jobs_services.job_model_to_job_submission(job_model) try: logger.debug("%s: registering replica for service %s", fmt(job_model), run.id.hex) @@ -225,17 +235,21 @@ async def register_replica( else: raise job_model.registered = True - logger.info( - "%s: service replica registered to receive requests, gateway=%s", - fmt(job_model), - gateway_id is not None, + targets = [events.Target.from_model(job_model)] + if gateway is not None: + targets.append(events.Target.from_model(gateway)) + events.emit( + session, + "Service replica registered to receive requests", + actor=events.SystemActor(), + targets=targets, ) async def unregister_service(session: AsyncSession, run_model: RunModel): if run_model.gateway_id is None: # in-server proxy return - conn = await get_or_add_gateway_connection(session, run_model.gateway_id) + gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) res = await session.execute( select(ProjectModel).where(ProjectModel.id == run_model.project_id) ) @@ -247,24 +261,37 @@ async def unregister_service(session: AsyncSession, run_model: RunModel): project=project.name, run_name=run_model.run_name, ) - logger.debug("%s: service is unregistered", fmt(run_model)) + event_msg = "Service unregistered from gateway" except GatewayError as e: # ignore if service is not registered logger.warning("%s: unregistering service: %s", fmt(run_model), e) + event_msg = f"Gateway error when unregistering service: {e}" except (httpx.RequestError, SSHError) as e: logger.debug("Gateway request failed", exc_info=True) raise GatewayError(repr(e)) + events.emit( + session, + event_msg, + actor=events.SystemActor(), + targets=[ + events.Target.from_model(run_model), + events.Target.from_model(gateway), + ], + ) async def unregister_replica(session: AsyncSession, job_model: JobModel): + if not job_model.registered: # non-services and unregistered service replicas + return res = await session.execute( select(RunModel) .where(RunModel.id == job_model.run_id) - .options(joinedload(RunModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(RunModel.project)) ) run_model = res.unique().scalar_one() + gateway = None if run_model.gateway_id is not None: - conn = await get_or_add_gateway_connection(session, run_model.gateway_id) + gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) try: logger.debug( "%s: unregistering replica from service %s", fmt(job_model), job_model.run_id.hex @@ -282,10 +309,14 @@ async def unregister_replica(session: AsyncSession, job_model: JobModel): logger.debug("Gateway request failed", exc_info=True) raise GatewayError(repr(e)) job_model.registered = False - logger.info( - "%s: service replica unregistered from receiving requests, gateway=%s", - fmt(job_model), - run_model.gateway_id is not None, + targets = [events.Target.from_model(job_model)] + if gateway is not None: + targets.append(events.Target.from_model(gateway)) + events.emit( + session, + "Service replica unregistered from receiving requests", + actor=events.SystemActor(), + targets=targets, ) @@ -314,7 +345,7 @@ async def update_service_desired_replica_count( ) -> None: stats = None if run_model.gateway_id is not None: - conn = await get_or_add_gateway_connection(session, run_model.gateway_id) + _, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) stats = await conn.get_stats(run_model.project.name, run_model.run_name) replica_groups = configuration.replica_groups desired_replica_counts = {} diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index 9e318866c..12edeec20 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -1014,6 +1014,11 @@ async def test_registers_service_replica_immediately_if_no_probes( await session.refresh(job) assert job.status == JobStatus.RUNNING assert job.registered + events = await list_events(session) + assert {e.message for e in events} == { + "Job status changed PULLING -> RUNNING", + "Service replica registered to receive requests", + } @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -1104,7 +1109,14 @@ async def test_registers_service_replica_only_after_probes_pass( await process_running_jobs() await session.refresh(job) - assert job.registered == expect_to_register + events = await list_events(session) + if expect_to_register: + assert job.registered + assert len(events) == 1 + assert events[0].message == "Service replica registered to receive requests" + else: + assert not job.registered + assert not events class TestPatchBaseImageForAwsEfa: diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 70ab54bd1..d24382f35 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -63,6 +63,7 @@ get_fleet_spec, get_job_provisioning_data, get_run_spec, + list_events, ) from dstack._internal.server.testing.matchers import SomeUUID4Str from tests._internal.server.background.tasks.test_process_running_jobs import settings @@ -2085,6 +2086,7 @@ def mock_gateway_connections(self) -> Generator[None, None, None]: "specified_gateway_in_run_conf", "expected_service_url", "expected_model_url", + "is_gateway", ), [ pytest.param( @@ -2092,6 +2094,7 @@ def mock_gateway_connections(self) -> Generator[None, None, None]: None, "https://test-service.default-gateway.example", "https://gateway.default-gateway.example", + True, id="submits-to-default-gateway", ), pytest.param( @@ -2099,6 +2102,7 @@ def mock_gateway_connections(self) -> Generator[None, None, None]: True, "https://test-service.default-gateway.example", "https://gateway.default-gateway.example", + True, id="submits-to-default-gateway-when-gateway-true", ), pytest.param( @@ -2106,6 +2110,7 @@ def mock_gateway_connections(self) -> Generator[None, None, None]: "non-default-gateway", "https://test-service.non-default-gateway.example", "https://gateway.non-default-gateway.example", + True, id="submits-to-specified-gateway", ), pytest.param( @@ -2113,6 +2118,7 @@ def mock_gateway_connections(self) -> Generator[None, None, None]: None, "/proxy/services/test-project/test-service/", "/proxy/models/test-project/", + False, id="submits-in-server-when-no-default-gateway", ), pytest.param( @@ -2120,6 +2126,7 @@ def mock_gateway_connections(self) -> Generator[None, None, None]: False, "/proxy/services/test-project/test-service/", "/proxy/models/test-project/", + False, id="submits-in-server-when-specified", ), ], @@ -2130,9 +2137,10 @@ async def test_submit_to_correct_proxy( session: AsyncSession, client: AsyncClient, existing_gateways: List[Tuple[str, bool]], - specified_gateway_in_run_conf: str, + specified_gateway_in_run_conf: Union[str, bool, None], expected_service_url: str, expected_model_url: str, + is_gateway: bool, ) -> None: user = await create_user(session=session, global_role=GlobalRole.USER) project = await create_project(session=session, owner=user, name="test-project") @@ -2171,6 +2179,8 @@ async def test_submit_to_correct_proxy( assert response.status_code == 200 assert response.json()["service"]["url"] == expected_service_url assert response.json()["service"]["model"]["base_url"] == expected_model_url + events = await list_events(session) + assert ("Service registered in gateway" in {e.message for e in events}) == is_gateway @pytest.mark.asyncio async def test_return_error_if_specified_gateway_not_exists(