Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

### Features Added

### Breaking Changes

### Bugs Fixed
- Error source classification headers: All HTTP error responses now include `x-platform-error-source` with a value of `user`, `platform`, or `upstream` to indicate which component caused the error. Developer handler exceptions and missing handler registrations are classified as `upstream`. Exceptions tagged with the platform error tag are classified as `platform` and additionally include `x-platform-error-detail` with truncated exception details (max 2048 characters) for diagnostics.
- SSE keep-alive comment frames (`: keep-alive`) are now interleaved into `text/event-stream` responses returned by invoke handlers when the `SSE_KEEPALIVE_INTERVAL` environment variable is set to a positive integer (resolved via `AgentConfig.sse_keepalive_interval`). This prevents idle SSE connections from being closed by intermediate proxies and brings the invocations server to parity with the responses server.


### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,17 +350,25 @@ def _wrap_streaming_response(
response: StreamingResponse,
otel_span: Any,
) -> StreamingResponse:
"""Wrap a streaming response's body iterator with span lifecycle and context.

Two layers of wrapping are applied:

1. **Inner (tracing):** ``trace_stream`` wraps the body iterator so
the OTel span covers the full streaming duration and is ended
when iteration completes.
2. **Outer (context):** A second async generator re-attaches the span
as the current context for the duration of streaming, so that
child spans created by user handler code (e.g. Agent Framework)
are correctly parented under this span.
"""Wrap a streaming response's body iterator with span lifecycle, context, and SSE keep-alive.

Up to three layers of wrapping are applied (from innermost to
outermost):

1. **Tracing:** ``trace_stream`` wraps the body iterator so the OTel
span covers the full streaming duration and is ended when
iteration completes. Skipped when *otel_span* is ``None``.
2. **Context:** An async generator re-attaches the span as the
current context for the duration of streaming, so child spans
created by user handler code (e.g. Agent Framework) are
correctly parented under this span. Skipped when *otel_span*
is ``None``.
3. **SSE keep-alive:** When the response media type is
``text/event-stream`` and ``AgentConfig.sse_keepalive_interval``
is positive (driven by the ``SSE_KEEPALIVE_INTERVAL`` env var
set on the container), :meth:`AgentServerHost.sse_keepalive_stream`
interleaves ``: keep-alive`` SSE comment frames into idle
streams so intermediaries do not close the connection.

:param response: The ``StreamingResponse`` returned by the user handler.
:type response: ~starlette.responses.StreamingResponse
Expand All @@ -369,23 +377,31 @@ def _wrap_streaming_response(
:return: The same response object, with its body_iterator replaced.
:rtype: ~starlette.responses.StreamingResponse
"""
if otel_span is None:
return response

# Inner wrap: trace_stream ends the span when iteration completes.
traced = trace_stream(response.body_iterator, otel_span)

# Outer wrap: re-attach span as current context during streaming
# so child spans are correctly parented.
async def _iter_with_context(): # type: ignore[return-value]
token = set_current_span(otel_span)
try:
async for chunk in traced:
yield chunk
finally:
detach_context(token)
if otel_span is not None:
# Inner wrap: trace_stream ends the span when iteration completes.
traced = trace_stream(response.body_iterator, otel_span)

# Middle wrap: re-attach span as current context during streaming
# so child spans are correctly parented.
async def _iter_with_context(): # type: ignore[return-value]
token = set_current_span(otel_span)
try:
async for chunk in traced:
yield chunk
finally:
detach_context(token)

response.body_iterator = _iter_with_context()

# Outer wrap: interleave SSE keep-alive frames for text/event-stream
# responses when the platform has configured a positive interval via
# the SSE_KEEPALIVE_INTERVAL env var (resolved by AgentConfig).
keepalive_interval = self.config.sse_keepalive_interval
if keepalive_interval > 0 and response.media_type == "text/event-stream":
response.body_iterator = AgentServerHost.sse_keepalive_stream(
response.body_iterator, keepalive_interval
)

response.body_iterator = _iter_with_context()
return response

# ------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
"""Tests for SSE keep-alive interleaving on the invocations server.

The invocations server uses the ``SSE_KEEPALIVE_INTERVAL`` environment
variable (resolved via :class:`AgentConfig`) to drive keep-alive frame
injection. These tests exercise both the env-var-driven path and the
default-disabled path, and also verify that keep-alive is only injected
for ``text/event-stream`` responses (not for arbitrary streaming
content types such as NDJSON).
"""
import asyncio

import pytest
from httpx import ASGITransport, AsyncClient
from starlette.requests import Request
from starlette.responses import StreamingResponse

from azure.ai.agentserver.invocations import InvocationAgentServerHost


def _make_slow_sse_agent(delay_seconds: float = 0.6, event_count: int = 2) -> InvocationAgentServerHost:
    """Build a host whose invoke handler streams slowly-paced SSE events.

    The handler returns a ``text/event-stream`` ``StreamingResponse`` that
    yields *event_count* ``event: msg`` frames. The first frame is emitted
    immediately; every later frame is preceded by an
    ``asyncio.sleep(delay_seconds)`` pause.

    :param delay_seconds: Pause inserted before each event after the first.
    :param event_count: Total number of SSE events to emit.
    :return: A host with the streaming invoke handler registered.
    """
    host = InvocationAgentServerHost()

    async def _sse_frames():
        for index in range(event_count):
            # Only the gaps *between* events are delayed, never the first one.
            if index:
                await asyncio.sleep(delay_seconds)
            frame = f"event: msg\ndata: {{\"i\": {index}}}\n\n"
            yield frame.encode("utf-8")

    @host.invoke_handler
    async def handle(_request: Request) -> StreamingResponse:
        # A fresh generator is created for every request.
        return StreamingResponse(_sse_frames(), media_type="text/event-stream")

    return host


def _make_slow_ndjson_agent(delay_seconds: float = 0.6, event_count: int = 2) -> InvocationAgentServerHost:
    """Build a host whose invoke handler streams slowly-paced NDJSON lines.

    The handler returns an ``application/x-ndjson`` ``StreamingResponse``
    (deliberately *not* SSE) that yields *event_count* JSON lines. The first
    line is emitted immediately; every later line is preceded by an
    ``asyncio.sleep(delay_seconds)`` pause.

    :param delay_seconds: Pause inserted before each line after the first.
    :param event_count: Total number of NDJSON lines to emit.
    :return: A host with the streaming invoke handler registered.
    """
    host = InvocationAgentServerHost()

    async def _ndjson_lines():
        for index in range(event_count):
            # Only the gaps *between* lines are delayed, never the first one.
            if index:
                await asyncio.sleep(delay_seconds)
            record = f'{{"i": {index}}}\n'
            yield record.encode("utf-8")

    @host.invoke_handler
    async def handle(_request: Request) -> StreamingResponse:
        # A fresh generator is created for every request.
        return StreamingResponse(_ndjson_lines(), media_type="application/x-ndjson")

    return host


def _parse_lines(text: str) -> list[str]:
return text.splitlines()
Comment on lines +57 to +58


# ---------------------------------------------------------------------------
# Default (env var unset) — no keep-alive frames are emitted
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_sse_keepalive_disabled_by_default(monkeypatch):
    """Without SSE_KEEPALIVE_INTERVAL in the environment, the SSE body
    contains no ``: keep-alive`` lines."""
    # Make sure a value inherited from the outer environment cannot leak in.
    monkeypatch.delenv("SSE_KEEPALIVE_INTERVAL", raising=False)
    host = _make_slow_sse_agent(delay_seconds=0.4, event_count=2)

    async with AsyncClient(transport=ASGITransport(app=host), base_url="http://testserver") as client:
        response = await client.post("/invocations", content=b"")
        assert response.status_code == 200
        body_lines = _parse_lines(response.text)

    assert not any(entry.startswith(": keep-alive") for entry in body_lines)


# ---------------------------------------------------------------------------
# Env-var driven — keep-alive frames are injected into idle SSE streams
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_sse_keepalive_interleaves_frames_when_env_var_set(monkeypatch):
    """When SSE_KEEPALIVE_INTERVAL is set, ``: keep-alive`` frames appear
    during gaps between handler events."""
    monkeypatch.setenv("SSE_KEEPALIVE_INTERVAL", "1")
    # Construct the host AFTER setting the env var so AgentConfig.from_env()
    # picks up the value.
    # The gap between the two events (2.5) is deliberately larger than the
    # configured interval (1) so the stream is idle long enough for a
    # keep-alive frame to be emitted.
    app = _make_slow_sse_agent(delay_seconds=2.5, event_count=2)
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://testserver") as client:
        resp = await client.post("/invocations", content=b"")
        assert resp.status_code == 200
        lines = _parse_lines(resp.text)

    keepalive_lines = [line for line in lines if line.startswith(": keep-alive")]
    assert len(keepalive_lines) >= 1, f"Expected at least one keep-alive comment, got lines={lines!r}"
    # Original handler events are still present and intact.
    assert any(line == "event: msg" for line in lines)
    assert any(line.startswith("data:") for line in lines)


# ---------------------------------------------------------------------------
# Keep-alive must not be applied to non-SSE streaming responses
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_sse_keepalive_not_applied_to_non_sse_streams(monkeypatch):
    """An NDJSON stream must stay free of keep-alive comment frames even
    when SSE_KEEPALIVE_INTERVAL is set."""
    monkeypatch.setenv("SSE_KEEPALIVE_INTERVAL", "1")
    # The host is built after the env var is set, mirroring the SSE test.
    host = _make_slow_ndjson_agent(delay_seconds=2.5, event_count=2)

    async with AsyncClient(transport=ASGITransport(app=host), base_url="http://testserver") as client:
        response = await client.post("/invocations", content=b"")
        assert response.status_code == 200
        payload = response.text

    assert ": keep-alive" not in payload
Loading