Skip to content

Commit ecbb2ef

Browse files
committed
Add LLM retry logic and tests for providers/sessions
- Base LLMProvider now retries on 429/5xx/timeout/connect errors (2 retries, exponential backoff) - Providers implement _generate(), base class wraps with retry in generate() - New test_llm.py: provider selection, timeout config, retry/exhaust scenarios - New test_sessions.py: session CRUD, enqueue, TTL expiry, GC sweep - Test count: 42 → 63
1 parent 52074a1 commit ecbb2ef

6 files changed

Lines changed: 314 additions & 3 deletions

File tree

app/llm/base.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,51 @@
1+
import asyncio
2+
import logging
13
from typing import Any, Dict, List
24

5+
import httpx
6+
7+
logger = logging.getLogger(__name__)

# Retryable HTTP status codes (server errors + rate limiting)
_RETRYABLE_STATUS = {429, 500, 502, 503, 504}

LLM_MAX_RETRIES = 2
LLM_RETRY_BASE_DELAY = 1.0  # seconds


class LLMProvider:
    """Base class for LLM providers.

    Subclasses implement ``_generate()``; ``generate()`` wraps it with
    retry handling for transient failures (HTTP 429/5xx, timeouts,
    connection errors) using exponential backoff.
    """

    async def generate(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Call ``_generate()`` with up to ``LLM_MAX_RETRIES`` retries.

        Retries on HTTP status codes in ``_RETRYABLE_STATUS`` and on
        timeout/connection errors, sleeping
        ``LLM_RETRY_BASE_DELAY * 2**attempt`` seconds between attempts.
        Non-retryable HTTP errors are raised immediately; once retries
        are exhausted, the last transient error is re-raised.
        """
        last_exc: Exception | None = None
        for attempt in range(1 + LLM_MAX_RETRIES):
            try:
                return await self._generate(messages)
            except httpx.HTTPStatusError as exc:
                # Client errors other than 429 are not transient: fail fast.
                if exc.response.status_code not in _RETRYABLE_STATUS:
                    raise
                last_exc = exc
                reason = str(exc.response.status_code)
            except (httpx.TimeoutException, httpx.ConnectError) as exc:
                last_exc = exc
                reason = type(exc).__name__
            # Shared retry tail for both transient-failure branches
            # (previously duplicated verbatim in each except clause).
            if attempt < LLM_MAX_RETRIES:
                # Exponential backoff: base, 2*base, 4*base, ...
                # LLM_RETRY_BASE_DELAY is read at call time so tests can patch it.
                delay = LLM_RETRY_BASE_DELAY * (2**attempt)
                logger.warning(
                    "LLM request failed (%s), retrying in %.1fs (attempt %d/%d)",
                    reason,
                    delay,
                    attempt + 1,
                    LLM_MAX_RETRIES,
                )
                await asyncio.sleep(delay)
        # The loop only falls through after a transient failure, so
        # last_exc is always set here.
        raise last_exc  # type: ignore[misc]

    async def _generate(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Perform one provider-specific request; implemented by subclasses."""
        raise NotImplementedError

app/llm/providers/anthropic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, api_key: str, model: str, base_url: str | None, timeout: floa
2020
self.base_url = (base_url or "https://api.anthropic.com").rstrip("/")
2121
self.timeout = timeout
2222

23-
async def generate(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
23+
async def _generate(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
2424
if not self.api_key:
2525
raise ValueError("LLM_API_KEY is required for Anthropic provider")
2626

app/llm/providers/chat_completions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, api_key: str, model: str, base_url: str | None, timeout: floa
1818
self.base_url = (base_url or "https://api.openai.com").rstrip("/")
1919
self.timeout = timeout
2020

21-
async def generate(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
21+
async def _generate(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
2222
url = f"{self.base_url}/v1/chat/completions"
2323
payload: Dict[str, Any] = {
2424
"model": self.model,

app/llm/providers/ollama.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, model: str, base_url: str | None, timeout: float = 60) -> Non
1111
self.base_url = (base_url or "http://localhost:11434").rstrip("/")
1212
self.timeout = timeout
1313

14-
async def generate(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
14+
async def _generate(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
1515
url = f"{self.base_url}/api/chat"
1616
payload = {
1717
"model": self.model,

tests/test_llm.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
"""Tests for LLM provider selection and retry logic."""
2+
3+
from unittest.mock import patch
4+
5+
import httpx
6+
import pytest
7+
from app.assistant.service import _build_provider
8+
from app.config import Config
9+
from app.llm.base import LLM_MAX_RETRIES, LLMProvider
10+
from app.llm.providers.anthropic import AnthropicProvider
11+
from app.llm.providers.chat_completions import ChatCompletionsProvider
12+
from app.llm.providers.ollama import OllamaProvider
13+
14+
pytestmark = pytest.mark.asyncio
15+
16+
17+
def _make_config(**overrides) -> Config:
    """Build a Config with sensible test defaults, overridable per test."""
    settings = {
        "db_url": "sqlite:///:memory:",
        "mode": "read-only",
        "limit_default": 100,
        "timeout_ms": 5000,
        "enable_ui": True,
        "enable_explanations": True,
        "allowed_origins": ["http://localhost:8000"],
        "allow_destructive": False,
        "llm_provider": "openai",
        "llm_api_key": "test-key",
        "llm_model": "test-model",
        "llm_base_url": None,
        "openai_api_mode": "chat",
        "llm_timeout_ms": 60000,
        "chat_history_enabled": True,
        "chat_history_limit": 10,
    }
    settings.update(overrides)
    return Config(**settings)
38+
39+
40+
# ── Provider selection ─────────────────────────────────────────────────────
41+
42+
43+
def test_build_provider_openai() -> None:
    """openai provider maps to ChatCompletionsProvider with the OpenAI base URL."""
    provider = _build_provider(_make_config(llm_provider="openai"))
    assert isinstance(provider, ChatCompletionsProvider)
    assert provider.base_url == "https://api.openai.com"
48+
49+
50+
def test_build_provider_anthropic() -> None:
    """anthropic provider maps to AnthropicProvider."""
    provider = _build_provider(_make_config(llm_provider="anthropic"))
    assert isinstance(provider, AnthropicProvider)
54+
55+
56+
def test_build_provider_ollama() -> None:
    """ollama provider maps to OllamaProvider with the local default URL."""
    provider = _build_provider(_make_config(llm_provider="ollama"))
    assert isinstance(provider, OllamaProvider)
    assert provider.base_url == "http://localhost:11434"
61+
62+
63+
def test_build_provider_deepseek() -> None:
    """deepseek uses the OpenAI-compatible provider with its own base URL."""
    provider = _build_provider(_make_config(llm_provider="deepseek"))
    assert isinstance(provider, ChatCompletionsProvider)
    assert provider.base_url == "https://api.deepseek.com"
68+
69+
70+
def test_build_provider_gemini() -> None:
    """gemini uses the OpenAI-compatible provider against generativelanguage."""
    provider = _build_provider(_make_config(llm_provider="gemini"))
    assert isinstance(provider, ChatCompletionsProvider)
    assert "generativelanguage" in provider.base_url
75+
76+
77+
def test_build_provider_custom_base_url() -> None:
    """An explicit llm_base_url overrides the provider's default endpoint."""
    cfg = _make_config(
        llm_provider="openai", llm_base_url="https://my-proxy.example.com"
    )
    provider = _build_provider(cfg)
    assert isinstance(provider, ChatCompletionsProvider)
    assert provider.base_url == "https://my-proxy.example.com"
82+
83+
84+
def test_build_provider_timeout_passed() -> None:
    """llm_timeout_ms is converted to seconds on the provider."""
    provider = _build_provider(_make_config(llm_timeout_ms=30000))
    assert provider.timeout == 30.0
88+
89+
90+
# ── Retry logic ────────────────────────────────────────────────────────────
91+
92+
93+
class _FlakyProvider(LLMProvider):
    """Provider whose ``_generate`` raises ``exc`` for the first ``fail_times`` calls."""

    def __init__(self, fail_times: int, exc: Exception) -> None:
        self.fail_times = fail_times
        self.exc = exc
        self.attempts = 0  # incremented on every _generate call

    async def _generate(self, messages):
        self.attempts += 1
        if self.attempts > self.fail_times:
            return {"text": "ok", "raw": {}}
        raise self.exc
106+
107+
108+
def _make_http_error(status: int) -> httpx.HTTPStatusError:
    """Build an HTTPStatusError whose response carries the given status code."""
    request = httpx.Request("POST", "http://x")
    response = httpx.Response(status_code=status)
    return httpx.HTTPStatusError(message=str(status), request=request, response=response)
113+
114+
115+
async def test_retry_on_500() -> None:
    """A single 500 is retried and the second attempt succeeds."""
    flaky = _FlakyProvider(fail_times=1, exc=_make_http_error(500))
    with patch("app.llm.base.LLM_RETRY_BASE_DELAY", 0):
        response = await flaky.generate([])
    assert response["text"] == "ok"
    assert flaky.attempts == 2
121+
122+
123+
async def test_retry_on_429() -> None:
    """Rate limiting (429) is treated as transient and retried."""
    flaky = _FlakyProvider(fail_times=1, exc=_make_http_error(429))
    with patch("app.llm.base.LLM_RETRY_BASE_DELAY", 0):
        response = await flaky.generate([])
    assert response["text"] == "ok"
    assert flaky.attempts == 2
129+
130+
131+
async def test_no_retry_on_400() -> None:
    """Client errors (4xx other than 429) fail fast without retrying."""
    flaky = _FlakyProvider(fail_times=1, exc=_make_http_error(400))
    with pytest.raises(httpx.HTTPStatusError):
        await flaky.generate([])
    assert flaky.attempts == 1  # no retry for client errors
136+
137+
138+
async def test_retry_on_timeout() -> None:
    """Timeouts are transient: one failure, then a successful retry."""
    flaky = _FlakyProvider(fail_times=1, exc=httpx.TimeoutException("timeout"))
    with patch("app.llm.base.LLM_RETRY_BASE_DELAY", 0):
        response = await flaky.generate([])
    assert response["text"] == "ok"
    assert flaky.attempts == 2
144+
145+
146+
async def test_retry_on_connect_error() -> None:
    """Connection errors are transient: one failure, then a successful retry."""
    provider = _FlakyProvider(fail_times=1, exc=httpx.ConnectError("refused"))
    with patch("app.llm.base.LLM_RETRY_BASE_DELAY", 0):
        result = await provider.generate([])
    assert result["text"] == "ok"
    # Consistent with the other retry tests: verify exactly one retry happened.
    assert provider.attempts == 2
151+
152+
153+
async def test_retry_exhausted_raises() -> None:
    """When every attempt fails, the last error propagates after max retries."""
    flaky = _FlakyProvider(fail_times=10, exc=_make_http_error(503))
    with patch("app.llm.base.LLM_RETRY_BASE_DELAY", 0):
        with pytest.raises(httpx.HTTPStatusError):
            await flaky.generate([])
    assert flaky.attempts == 1 + LLM_MAX_RETRIES

tests/test_sessions.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Tests for SSE session management (create, get, expire, GC)."""
2+
3+
import asyncio
4+
import time
5+
from unittest.mock import patch
6+
7+
import pytest
8+
from app import main as main_module
9+
from app.main import (
10+
SESSION_TTL,
11+
_create_session,
12+
_enqueue,
13+
_get_session,
14+
_remove_session,
15+
_sessions,
16+
_sessions_lock,
17+
)
18+
19+
pytestmark = pytest.mark.asyncio
20+
21+
22+
async def _clear_sessions() -> None:
    """Empty the module-level session registry so each test starts clean."""
    async with _sessions_lock:
        _sessions.clear()
25+
26+
27+
async def test_create_and_get_session() -> None:
    """A freshly created session has a non-empty id and a retrievable queue."""
    await _clear_sessions()
    session_id = await _create_session()
    assert session_id
    assert await _get_session(session_id) is not None
33+
34+
35+
async def test_get_nonexistent_session() -> None:
    """Looking up an unknown session id yields None rather than raising."""
    await _clear_sessions()
    assert await _get_session("nonexistent-id") is None
39+
40+
41+
async def test_remove_session() -> None:
    """A removed session is no longer retrievable."""
    await _clear_sessions()
    session_id = await _create_session()
    await _remove_session(session_id)
    assert await _get_session(session_id) is None
47+
48+
49+
async def test_remove_nonexistent_session() -> None:
    """Removing an unknown session id is a no-op, not an error."""
    await _clear_sessions()
    await _remove_session("nonexistent-id")
53+
54+
55+
async def test_enqueue_and_dequeue() -> None:
    """A payload enqueued to a live session comes back serialized on its queue."""
    await _clear_sessions()
    session_id = await _create_session()
    assert await _enqueue(session_id, {"test": "payload"})
    session_queue = await _get_session(session_id)
    message = session_queue.get_nowait()
    assert '"test"' in message
63+
64+
65+
async def test_enqueue_nonexistent_session() -> None:
    """Enqueueing to an unknown session id reports failure instead of raising."""
    await _clear_sessions()
    assert not await _enqueue("nonexistent-id", {"test": "payload"})
69+
70+
71+
async def test_get_session_updates_last_seen() -> None:
    """_get_session bumps the session's last-seen timestamp."""
    await _clear_sessions()
    session_id = await _create_session()

    async def _last_seen() -> float:
        async with _sessions_lock:
            return _sessions[session_id][1]

    before = await _last_seen()
    await asyncio.sleep(0.01)
    await _get_session(session_id)
    assert await _last_seen() >= before
81+
82+
83+
async def test_gc_removes_expired_sessions() -> None:
    """One GC sweep removes a session whose last-seen is past the TTL."""
    await _clear_sessions()
    session_id = await _create_session()

    # Backdate last_seen so the session is already expired.
    async with _sessions_lock:
        session_queue, _ = _sessions[session_id]
        _sessions[session_id] = (session_queue, time.time() - SESSION_TTL - 10)

    # Let the GC loop run exactly one cycle: the first patched sleep
    # returns normally, the second cancels the loop.
    sleep_calls = 0

    async def _fake_sleep(seconds):
        nonlocal sleep_calls
        sleep_calls += 1
        if sleep_calls > 1:
            raise asyncio.CancelledError

    with patch("app.main.asyncio.sleep", side_effect=_fake_sleep):
        try:
            await main_module._gc_sessions()
        except asyncio.CancelledError:
            pass

    # The expired session must have been swept.
    assert await _get_session(session_id) is None

0 commit comments

Comments
 (0)