Skip to content

Commit fe90a96

Browse files
committed
Added test suite for stream_io
1 parent b375d16 commit fe90a96

File tree

3 files changed

+259
-4
lines changed

3 files changed

+259
-4
lines changed

poetry.lock

Lines changed: 20 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ pytest-cov = "5.0.0"
5959
##########################
6060
# extras
6161
##########################
62+
pytest-anyio = "^0.0.0"
6263

6364
[tool.poetry.extras]
6465
parsing = ["antlr4-python3-runtime", "lark", "latex2sympy"]

tests/io/stream_io_test.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
import pytest
2+
import anyio
3+
4+
from lf_toolkit.io.stream_io import StreamIO, PrefixStreamIO, StreamServer
5+
6+
7+
@pytest.fixture
8+
def anyio_backend():
9+
return "asyncio"
10+
11+
12+
13+
def make_framed_message(payload: str) -> bytes:
14+
"""Wrap a JSON string in Content-Length framing."""
15+
body = payload.encode("utf-8")
16+
header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8")
17+
return header + body
18+
19+
20+
class FakeStreamIO(StreamIO):
21+
"""
22+
Simulates a bidirectional byte stream.
23+
Feed messages via feed(), read responses via responses.
24+
"""
25+
26+
def __init__(self):
27+
self._buffer = b""
28+
self.responses = []
29+
self.close_count = 0
30+
31+
def feed(self, data: bytes):
32+
self._buffer += data
33+
34+
async def read(self, size: int) -> bytes:
35+
if not self._buffer:
36+
raise anyio.EndOfStream()
37+
chunk = self._buffer[:size]
38+
self._buffer = self._buffer[size:]
39+
return chunk
40+
41+
async def write(self, data: bytes):
42+
self.responses.append(data)
43+
44+
async def close(self):
45+
self.close_count += 1
46+
47+
48+
class EchoServer(StreamServer):
49+
"""
50+
Concrete StreamServer for testing.
51+
- run() is required by BaseServer (abstract) but not used in tests
52+
since we call _handle_client directly.
53+
- dispatch() is overridden to echo the raw request back, bypassing
54+
the real JsonRpcHandler so tests stay self-contained.
55+
"""
56+
57+
async def run(self):
58+
pass
59+
60+
async def dispatch(self, data: str) -> str:
61+
return data
62+
63+
64+
class BuggyStreamServer(StreamServer):
65+
"""
66+
Reproduces the original bug by overriding _handle_client with
67+
close() inside the finally block.
68+
"""
69+
70+
async def run(self):
71+
pass
72+
73+
async def dispatch(self, data: str) -> str:
74+
return data
75+
76+
async def _handle_client(self, client: StreamIO):
77+
io = self.wrap_io(client)
78+
while True:
79+
try:
80+
data = await io.read(4096)
81+
if not data:
82+
break
83+
response = await self.dispatch(data.decode("utf-8"))
84+
await io.write(response.encode("utf-8"))
85+
except anyio.EndOfStream:
86+
break
87+
except anyio.ClosedResourceError:
88+
break
89+
except Exception as e:
90+
print(f"Exception: {e}")
91+
finally:
92+
await client.close() # BUG: closes after every message
93+
94+
95+
# ---------------------------------------------------------------------------
96+
# Tests
97+
# ---------------------------------------------------------------------------
98+
99+
class TestStreamServer:
100+
101+
@pytest.fixture
102+
def stream(self):
103+
return FakeStreamIO()
104+
105+
@pytest.fixture
106+
def server(self):
107+
return EchoServer()
108+
109+
@pytest.fixture
110+
def buggy_server(self):
111+
return BuggyStreamServer()
112+
113+
@pytest.mark.anyio
114+
async def test_handles_multiple_messages(self, stream, server):
115+
"""
116+
Core fix test: the server must process multiple messages in a single
117+
session without closing the connection between them.
118+
"""
119+
stream.feed(make_framed_message('{"command": "eval", "id": 1}'))
120+
stream.feed(make_framed_message('{"command": "eval", "id": 2}'))
121+
stream.feed(make_framed_message('{"command": "eval", "id": 3}'))
122+
123+
await server._handle_client(stream)
124+
125+
assert len(stream.responses) == 3, (
126+
f"Expected 3 responses but got {len(stream.responses)}. "
127+
"Server likely closed the connection after the first message."
128+
)
129+
130+
@pytest.mark.anyio
131+
async def test_closes_only_once(self, stream, server):
132+
"""
133+
The client connection should be closed exactly once — after the loop
134+
exits — not once per message.
135+
"""
136+
stream.feed(make_framed_message('{"id": 1}'))
137+
stream.feed(make_framed_message('{"id": 2}'))
138+
139+
await server._handle_client(stream)
140+
141+
assert stream.close_count == 1, (
142+
f"Expected close() to be called once, but it was called "
143+
f"{stream.close_count} times. This is the original bug."
144+
)
145+
146+
@pytest.mark.anyio
147+
async def test_buggy_server_closes_after_each_message(self, stream, buggy_server):
148+
"""
149+
Demonstrates the original bug: close() in the finally block causes
150+
the stream to be closed after every message, not just at the end.
151+
"""
152+
stream.feed(make_framed_message('{"id": 1}'))
153+
stream.feed(make_framed_message('{"id": 2}'))
154+
155+
await buggy_server._handle_client(stream)
156+
157+
assert stream.close_count > 1, (
158+
"Expected buggy server to call close() more than once, "
159+
"confirming the bug exists in the original code."
160+
)
161+
162+
@pytest.mark.anyio
163+
async def test_single_message(self, stream, server):
164+
"""A single message round-trip should work correctly."""
165+
payload = '{"command": "eval", "response": "test"}'
166+
stream.feed(make_framed_message(payload))
167+
168+
await server._handle_client(stream)
169+
170+
assert len(stream.responses) == 1
171+
assert payload.encode() in stream.responses[0]
172+
173+
@pytest.mark.anyio
174+
async def test_closes_on_empty_stream(self, stream, server):
175+
"""Server should exit cleanly when the stream ends with no data."""
176+
await server._handle_client(stream)
177+
178+
assert stream.close_count == 1
179+
180+
@pytest.mark.anyio
181+
async def test_response_content(self, stream, server):
182+
"""Verify the actual response content is correct across messages."""
183+
messages = [
184+
'{"id": 1, "command": "eval"}',
185+
'{"id": 2, "command": "preview"}',
186+
]
187+
188+
for msg in messages:
189+
stream.feed(make_framed_message(msg))
190+
191+
await server._handle_client(stream)
192+
193+
assert len(stream.responses) == 2
194+
for i, msg in enumerate(messages):
195+
assert msg.encode() in stream.responses[i]
196+
197+
198+
class TestPrefixStreamIO:
199+
200+
@pytest.fixture
201+
def stream(self):
202+
return FakeStreamIO()
203+
204+
@pytest.mark.anyio
205+
async def test_framing_round_trip(self, stream):
206+
"""PrefixStreamIO correctly encodes and decodes Content-Length framing."""
207+
prefix_io = PrefixStreamIO(stream)
208+
209+
payload = b'{"command": "eval"}'
210+
header = f"Content-Length: {len(payload)}\r\n\r\n".encode()
211+
stream.feed(header + payload)
212+
213+
result = await prefix_io.read(4096)
214+
assert result == payload
215+
216+
@pytest.mark.anyio
217+
async def test_write_includes_content_length_header(self, stream):
218+
"""PrefixStreamIO write includes correct Content-Length header."""
219+
prefix_io = PrefixStreamIO(stream)
220+
221+
payload = b'{"result": "ok"}'
222+
await prefix_io.write(payload)
223+
224+
assert len(stream.responses) == 1
225+
written = stream.responses[0]
226+
assert b"Content-Length:" in written
227+
assert f"{len(payload)}".encode() in written
228+
assert payload in written
229+
230+
@pytest.mark.anyio
231+
async def test_raises_on_missing_content_length(self, stream):
232+
"""PrefixStreamIO should raise if Content-Length header is absent."""
233+
prefix_io = PrefixStreamIO(stream)
234+
235+
stream.feed(b"X-Custom-Header: something\r\n\r\n")
236+
237+
with pytest.raises(ValueError, match="Content-Length"):
238+
await prefix_io.read(4096)

0 commit comments

Comments
 (0)