diff --git a/lf_toolkit/evaluation/image_upload.py b/lf_toolkit/evaluation/image_upload.py index dff9233..d0d69f4 100644 --- a/lf_toolkit/evaluation/image_upload.py +++ b/lf_toolkit/evaluation/image_upload.py @@ -101,7 +101,7 @@ def get_aws_signed_request(full_url, buffer, mime_type): aws_request = AWSRequest( method='PUT', url=full_url, - data=buffer, + data=data, headers=headers ) diff --git a/lf_toolkit/io/rpc_handler.py b/lf_toolkit/io/rpc_handler.py index fe3fadb..354ae95 100644 --- a/lf_toolkit/io/rpc_handler.py +++ b/lf_toolkit/io/rpc_handler.py @@ -4,6 +4,7 @@ from jsonrpcserver import Success from jsonrpcserver import async_dispatch +from ..shared import Command from .handler import Handler @@ -23,10 +24,10 @@ async def dispatch(self, req: str) -> str: ) -def jsonrpc_handler(handler: Handler, name: str): +def jsonrpc_handler(handler: Handler, name: Command): async def wrapped(req: dict): try: - result = await handler.handle(name, req) + result = await handler.handle(name, {"params": req}) return Success(result) except Exception as e: return Error(0, str(e), e) diff --git a/lf_toolkit/io/stdio_server.py b/lf_toolkit/io/stdio_server.py index cbffea2..3d6d567 100644 --- a/lf_toolkit/io/stdio_server.py +++ b/lf_toolkit/io/stdio_server.py @@ -2,6 +2,7 @@ from typing import Optional +import anyio from anyio.streams.file import FileReadStream from anyio.streams.file import FileWriteStream from anyio.streams.stapled import StapledByteStream @@ -15,9 +16,10 @@ class StdioClient(StreamIO): def __init__(self): + self._stdout_buffer = sys.stdout.buffer self.stream = StapledByteStream( - FileWriteStream(sys.stdout), - FileReadStream(sys.stdin), + FileWriteStream(self._stdout_buffer), + FileReadStream(sys.stdin.buffer), ) async def read(self, size: int) -> bytes: @@ -25,7 +27,7 @@ async def read(self, size: int) -> bytes: async def write(self, data: bytes): await self.stream.send(data) - await self.stream.flush() + await anyio.to_thread.run_sync(self._stdout_buffer.flush) async def close(self): await self.stream.aclose() @@ -37,10 +39,11 @@ class StdioServer(StreamServer): def __init__(self, handler: Optional[Handler] = None): super().__init__(handler) - self._client = StdioClient() def wrap_io(self, client: StreamIO) -> StreamIO: return PrefixStreamIO(client) async def run(self): + print("StdioServer started", file=sys.stderr, flush=True) + self._client = StdioClient() await self._handle_client(self._client) diff --git a/lf_toolkit/io/stream_io.py b/lf_toolkit/io/stream_io.py index 34c835f..223465b 100644 --- a/lf_toolkit/io/stream_io.py +++ b/lf_toolkit/io/stream_io.py @@ -63,10 +63,11 @@ async def read(self, size: int) -> bytes: if content_length == 0: raise ValueError("Content-Length header not found or is zero") - if content_length > size: - raise ValueError("Content-Length is larger than the read size") - - return await self.base.read(content_length) + data = b"" + while len(data) < content_length: + chunk = await self.base.read(content_length - len(data)) + data += chunk + return data async def write(self, data: bytes): response_headers_str = f"Content-Length: {len(data)}\r\n\r\n" @@ -84,22 +85,27 @@ async def _handle_client(self, client: StreamIO): while True: try: + import sys + print("waiting for data...", file=sys.stderr, flush=True) data = await io.read(4096) + print(f"got data: {data[:80]}", file=sys.stderr, flush=True) if not data: - # print("Received empty data") break + print("dispatching...", file=sys.stderr, flush=True) response = await self.dispatch(data.decode("utf-8")) + print(f"got response: {str(response)[:80]}", file=sys.stderr, flush=True) await io.write(response.encode("utf-8")) + print("wrote response", file=sys.stderr, flush=True) except anyio.EndOfStream: - # print("Client disconnected") break except anyio.ClosedResourceError: - # print("Client disconnected") break except Exception as e: - print(f"Exception: {e}") - finally: - await client.close() + import traceback + traceback.print_exc(file=sys.stderr) + break + + await client.close() diff --git a/poetry.lock b/poetry.lock index 58fff1f..3fcf25e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -32,7 +32,7 @@ version = "4.6.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "anyio-4.6.0-py3-none-any.whl", hash = "sha256:c7d2e9d63e31599eeb636c8c5c03a7e108d73b345f064f1c19fdc87b79036a9a"}, {file = "anyio-4.6.0.tar.gz", hash = "sha256:137b4559cbb034c477165047febb6ff83f390fc3b20bf181c1fc0a728cb8beeb"}, @@ -912,7 +912,7 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -1839,6 +1839,22 @@ pluggy = ">=1.5,<2" [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-anyio" +version = "0.0.0" +description = "The pytest anyio plugin is built into anyio. You don't need this package." +optional = false +python-versions = "*" +groups = ["dev"] +files = [ + {file = "pytest-anyio-0.0.0.tar.gz", hash = "sha256:b41234e9e9ad7ea1dbfefcc1d6891b23d5ef7c9f07ccf804c13a9cc338571fd3"}, + {file = "pytest_anyio-0.0.0-py2.py3-none-any.whl", hash = "sha256:dc8b5c4741cb16ff90be37fddd585ca943ed12bbeb563de7ace6cd94441d8746"}, +] + +[package.dependencies] +anyio = "*" +pytest = "*" + [[package]] name = "pytest-asyncio" version = "1.2.0" @@ -2398,7 +2414,7 @@ version = "1.3.1" description = "Sniff out which async library your code is running under" optional = false python-versions = ">=3.7" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, @@ -2848,4 +2864,4 @@ parsing = ["antlr4-python3-runtime", "lark", "latex2sympy"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "9dc3f7e12199191cf41834205dbb2705b1e1e4b2dd851b1bb57e312d3c4e8a8b" +content-hash = "828a10ad95eed705e623f10d27ef6d21568caf98e05636c91cca9246c34b7b58" diff --git a/pyproject.toml b/pyproject.toml index e06cf72..f1c6066 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ pytest-cov = "5.0.0" ########################## # extras ########################## +pytest-anyio = "^0.0.0" [tool.poetry.extras] parsing = ["antlr4-python3-runtime", "lark", "latex2sympy"] diff --git a/tests/io/file_server.py b/tests/io/file_server_test.py similarity index 100% rename from tests/io/file_server.py rename to tests/io/file_server_test.py diff --git a/tests/io/stream_io_test.py b/tests/io/stream_io_test.py new file mode 100644 index 0000000..09bb277 --- /dev/null +++ b/tests/io/stream_io_test.py @@ -0,0 +1,237 @@ +import subprocess +import sys + +import pytest +import anyio + +from lf_toolkit.io.stream_io import StreamIO, PrefixStreamIO, StreamServer +from lf_toolkit.io.stdio_server import StdioServer + + +@pytest.fixture +def anyio_backend(): + return "asyncio" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def make_framed_message(payload: str) -> bytes: + """Wrap a JSON string in Content-Length framing.""" + body = payload.encode("utf-8") + header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8") + return header + body + + +class FakeStreamIO(StreamIO): + """ + Simulates a bidirectional byte stream. + Feed messages via feed(), read responses via responses. + """ + + def __init__(self): + self._buffer = b"" + self.responses = [] + self.close_count = 0 + + def feed(self, data: bytes): + self._buffer += data + + async def read(self, size: int) -> bytes: + if not self._buffer: + raise anyio.EndOfStream() + chunk = self._buffer[:size] + self._buffer = self._buffer[size:] + return chunk + + async def write(self, data: bytes): + self.responses.append(data) + + async def close(self): + self.close_count += 1 + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestStdioServer: + + @pytest.fixture + def stream(self): + return FakeStreamIO() + + @pytest.fixture + def server(self): + return StdioServer() + + @pytest.mark.anyio + async def test_handles_multiple_messages(self, stream, server): + """ + Core fix test: the server must process multiple messages in a single + session without closing the connection between them. + """ + stream.feed(make_framed_message('{"jsonrpc":"2.0","method":"eval","params":{},"id":1}')) + stream.feed(make_framed_message('{"jsonrpc":"2.0","method":"eval","params":{},"id":2}')) + stream.feed(make_framed_message('{"jsonrpc":"2.0","method":"eval","params":{},"id":3}')) + + await server._handle_client(stream) + + assert len(stream.responses) == 3, ( + f"Expected 3 responses but got {len(stream.responses)}. " + "Server likely closed the connection after the first message." + ) + + @pytest.mark.anyio + async def test_closes_only_once(self, stream, server): + """ + The client connection should be closed exactly once — after the loop + exits — not once per message. + """ + stream.feed(make_framed_message('{"jsonrpc":"2.0","method":"eval","params":{},"id":1}')) + stream.feed(make_framed_message('{"jsonrpc":"2.0","method":"eval","params":{},"id":2}')) + + await server._handle_client(stream) + + assert stream.close_count == 1, ( + f"Expected close() to be called once, but it was called " + f"{stream.close_count} times. This is the original bug." + ) + + @pytest.mark.anyio + async def test_single_message(self, stream, server): + """A single message round-trip should work correctly.""" + stream.feed(make_framed_message('{"jsonrpc":"2.0","method":"eval","params":{},"id":1}')) + + await server._handle_client(stream) + + assert len(stream.responses) == 1 + # Response is a framed JSON-RPC envelope + assert b"Content-Length:" in stream.responses[0] + assert b"jsonrpc" in stream.responses[0] + + @pytest.mark.anyio + async def test_closes_on_empty_stream(self, stream, server): + """Server should exit cleanly when the stream ends with no data.""" + await server._handle_client(stream) + + assert stream.close_count == 1 + + @pytest.mark.anyio + async def test_response_content(self, stream, server): + """Verify a response is returned for each message sent.""" + messages = [ + '{"jsonrpc":"2.0","method":"eval","params":{},"id":1}', + '{"jsonrpc":"2.0","method":"preview","params":{},"id":2}', + ] + + for msg in messages: + stream.feed(make_framed_message(msg)) + + await server._handle_client(stream) + + assert len(stream.responses) == 2 + for response in stream.responses: + assert b"Content-Length:" in response + assert b"jsonrpc" in response + + +class TestPrefixStreamIO: + + @pytest.fixture + def stream(self): + return FakeStreamIO() + + @pytest.mark.anyio + async def test_framing_round_trip(self, stream): + """PrefixStreamIO correctly encodes and decodes Content-Length framing.""" + prefix_io = PrefixStreamIO(stream) + + payload = b'{"command": "eval"}' + header = f"Content-Length: {len(payload)}\r\n\r\n".encode() + stream.feed(header + payload) + + result = await prefix_io.read(4096) + assert result == payload + + @pytest.mark.anyio + async def test_write_includes_content_length_header(self, stream): + """PrefixStreamIO write includes correct Content-Length header.""" + prefix_io = PrefixStreamIO(stream) + + payload = b'{"result": "ok"}' + await prefix_io.write(payload) + + assert len(stream.responses) == 1 + written = stream.responses[0] + assert b"Content-Length:" in written + assert f"{len(payload)}".encode() in written + assert payload in written + + @pytest.mark.anyio + async def test_raises_on_missing_content_length(self, stream): + """PrefixStreamIO should raise if Content-Length header is absent.""" + prefix_io = PrefixStreamIO(stream) + + stream.feed(b"X-Custom-Header: something\r\n\r\n") + + with pytest.raises(ValueError, match="Content-Length"): + await prefix_io.read(4096) + + @pytest.mark.anyio + async def test_large_payload_does_not_raise(self, stream): + """Payloads larger than 4096 bytes must be read without raising.""" + prefix_io = PrefixStreamIO(stream) + + payload = b"x" * 8192 + header = f"Content-Length: {len(payload)}\r\n\r\n".encode() + stream.feed(header + payload) + + result = await prefix_io.read(4096) + assert result == payload + + @pytest.mark.anyio + async def test_exact_read_of_partial_chunks(self, stream): + """All bytes are read even when the underlying stream delivers chunks smaller than content_length.""" + prefix_io = PrefixStreamIO(stream) + + payload = b"a" * 100 + header = f"Content-Length: {len(payload)}\r\n\r\n".encode() + # Feed header and payload as separate tiny chunks (10 bytes each) + full = header + payload + for i in range(0, len(full), 10): + stream.feed(full[i:i + 10]) + + result = await prefix_io.read(4096) + assert result == payload + assert len(result) == 100 + + +class TestStdioServerSubprocess: + + def test_binary_pipe_roundtrip(self): + """ + Spawn the StdioServer as a subprocess and pipe a framed JSON-RPC + request to its stdin (as raw bytes). Confirms sys.stdin.buffer / + sys.stdout.buffer is used — text-mode streams would break this. + """ + msg = b'{"jsonrpc":"2.0","id":1,"method":"eval","params":{}}' + frame = f"Content-Length: {len(msg)}\r\n\r\n".encode() + msg + + proc = subprocess.Popen( + [sys.executable, "-c", + "import anyio; from lf_toolkit.io.stdio_server import StdioServer; " + "anyio.run(StdioServer().run)"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + stdout, stderr = proc.communicate(input=frame, timeout=5) + + # Must receive a framed response + assert b"Content-Length:" in stdout, ( + f"No framed response received.\nstderr: {stderr.decode()}" + ) + assert b"jsonrpc" in stdout