diff --git a/cli/serve/app.py b/cli/serve/app.py index df9eeab76..7dc375beb 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -1,6 +1,8 @@ """A simple app that runs an OpenAI compatible server wrapped around a M program.""" +import asyncio import importlib.util +import inspect import os import sys import time @@ -11,6 +13,8 @@ from fastapi import FastAPI from fastapi.responses import JSONResponse +from mellea.backends.model_options import ModelOption + from .models import ( ChatCompletion, ChatCompletionMessage, @@ -53,20 +57,54 @@ def create_openai_error_response( def make_chat_endpoint(module): """Makes a chat endpoint using a custom module.""" + def _build_model_options(request: ChatCompletionRequest) -> dict: + """Build model_options dict from OpenAI-compatible request parameters.""" + excluded_fields = { + # Request structure fields (handled separately) + "messages", # Chat messages - passed separately to serve() + "requirements", # Mellea requirements - passed separately to serve() + # Routing/metadata fields (not generation parameters) + "model", # Model identifier - used for routing, not generation + "n", # Number of completions - not supported in Mellea's model_options + "user", # User tracking ID - metadata, not a generation parameter + "extra", # Pydantic's extra fields dict - unused (see model_config) + } + openai_to_model_option = { + "temperature": ModelOption.TEMPERATURE, + "max_tokens": ModelOption.MAX_NEW_TOKENS, + "seed": ModelOption.SEED, + } + + filtered_options = { + key: value + for key, value in request.model_dump().items() + if key not in excluded_fields + } + return ModelOption.replace_keys(filtered_options, openai_to_model_option) + async def endpoint(request: ChatCompletionRequest): try: completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" created_timestamp = int(time.time()) - output = module.serve( - input=request.messages, - requirements=request.requirements, - model_options={ - k: v - for k, v in request.model_dump().items() - if k not in ["messages", "requirements"] - }, - ) + model_options = _build_model_options(request) + + # Detect if serve is async or sync and handle accordingly + if inspect.iscoroutinefunction(module.serve): + # It's async, await it directly + output = await module.serve( + input=request.messages, + requirements=request.requirements, + model_options=model_options, + ) + else: + # It's sync, run in thread pool to avoid blocking event loop + output = await asyncio.to_thread( + module.serve, + input=request.messages, + requirements=request.requirements, + model_options=model_options, + ) # Extract usage information from the ModelOutputThunk if available usage = None diff --git a/mellea/backends/model_options.py b/mellea/backends/model_options.py index 428dcab00..2350797ff 100644 --- a/mellea/backends/model_options.py +++ b/mellea/backends/model_options.py @@ -91,6 +91,10 @@ def replace_keys(options: dict, from_to: dict[str, str]) -> dict[str, Any]: # This will usually be a @@@<>@@@ ModelOption.<> key. new_key = from_to.get(old_key, None) if new_key: + # Skip if old_key and new_key are the same (no-op replacement) + if old_key == new_key: + continue + if new_options.get(new_key, None) is not None: # The key already has a value associated with it in the dict. Leave it be. conflict_log.append( diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index 584be0876..fb609dc5f 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -122,6 +122,8 @@ async def test_system_fingerprint_always_none(self, mock_module, sample_request) @pytest.mark.asyncio async def test_model_options_passed_correctly(self, mock_module, sample_request): """Test that model options are passed to serve function correctly.""" + from mellea.backends.model_options import ModelOption + mock_output = ModelOutputThunk("Test response") mock_module.serve.return_value = mock_output @@ -134,11 +136,12 @@ async def test_model_options_passed_correctly(self, mock_module, sample_request) assert "model_options" in call_args.kwargs model_options = call_args.kwargs["model_options"] - # Should include temperature and max_tokens but not messages/requirements - assert "temperature" in model_options - assert model_options["temperature"] == 0.7 - assert "max_tokens" in model_options - assert model_options["max_tokens"] == 100 + # Should include ModelOption keys for temperature and max_tokens + # Note: TEMPERATURE is just "temperature" (not a sentinel), so it stays as-is + assert ModelOption.TEMPERATURE in model_options + assert model_options[ModelOption.TEMPERATURE] == 0.7 + assert ModelOption.MAX_NEW_TOKENS in model_options + assert model_options[ModelOption.MAX_NEW_TOKENS] == 100 assert "messages" not in model_options assert "requirements" not in model_options diff --git a/test/cli/test_serve_sync_async.py b/test/cli/test_serve_sync_async.py new file mode 100644 index 000000000..8e0dab9f8 --- /dev/null +++ b/test/cli/test_serve_sync_async.py @@ -0,0 +1,255 @@ +"""Tests for sync/async serve function handling in m serve.""" + +import asyncio +from unittest.mock import AsyncMock, Mock + +import pytest + +from cli.serve.app import make_chat_endpoint +from cli.serve.models import ChatCompletionRequest, ChatMessage +from mellea.backends.model_options import ModelOption +from mellea.core import ModelOutputThunk + + +@pytest.fixture +def mock_sync_module(): + """Create a mock module with a synchronous serve function.""" + module = Mock() + module.__name__ = "test_sync_module" + + def sync_serve(input, requirements=None, model_options=None): + """Synchronous serve function.""" + # Simulate some work + return ModelOutputThunk(f"Sync response to: {input[-1].content}") + + # Use Mock to wrap the function so we can track calls + module.serve = Mock(side_effect=sync_serve) + return module + + +@pytest.fixture +def mock_async_module(): + """Create a mock module with an asynchronous serve function.""" + module = Mock() + module.__name__ = "test_async_module" + + async def async_serve(input, requirements=None, model_options=None): + """Asynchronous serve function.""" + # Simulate async work + await asyncio.sleep(0.01) + return ModelOutputThunk(f"Async response to: {input[-1].content}") + + module.serve = AsyncMock(side_effect=async_serve) + return module + + +@pytest.fixture +def mock_slow_sync_module(): + """Create a mock module with a slow synchronous serve function.""" + module = Mock() + module.__name__ = "test_slow_sync_module" + + def slow_sync_serve(input, requirements=None, model_options=None): + """Slow synchronous serve function that would block event loop.""" + import time + + time.sleep(1) # Simulate blocking work with a clearer timing signal + return ModelOutputThunk(f"Slow sync response to: {input[-1].content}") + + module.serve = slow_sync_serve + return module + + +class TestSyncAsyncServeHandling: + """Test that serve handles both sync and async serve functions correctly.""" + + @pytest.mark.asyncio + async def test_sync_serve_function(self, mock_sync_module): + """Test that synchronous serve functions work correctly.""" + endpoint = make_chat_endpoint(mock_sync_module) + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello sync!")], + ) + + response = await endpoint(request) + + assert response.choices[0].message.content == "Sync response to: Hello sync!" + assert response.model == "test-model" + assert response.object == "chat.completion" + + @pytest.mark.asyncio + async def test_async_serve_function(self, mock_async_module): + """Test that asynchronous serve functions work correctly.""" + endpoint = make_chat_endpoint(mock_async_module) + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello async!")], + ) + + response = await endpoint(request) + + assert response.choices[0].message.content == "Async response to: Hello async!" + assert response.model == "test-model" + assert response.object == "chat.completion" + + @pytest.mark.asyncio + async def test_slow_sync_does_not_block(self, mock_slow_sync_module): + """Test that slow sync functions run in thread pool and don't block event loop. + + This test verifies non-blocking behavior by measuring timing. If the sync + function blocked the event loop, two sequential calls would take 2x the time. + With proper threading, they should overlap and take only slightly more than 1x. + """ + import time + + endpoint = make_chat_endpoint(mock_slow_sync_module) + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello slow!")], + ) + + # Time two concurrent requests + start = time.time() + results = await asyncio.gather(endpoint(request), endpoint(request)) + elapsed = time.time() - start + + # If blocking: would take ~2s (1s + 1s sequentially) + # If non-blocking: should take ~1s (both run concurrently in threads) + # Allow some overhead, but should still be well below the blocking case. + assert elapsed < 2, ( + f"Took {elapsed:.3f}s - appears to be blocking (expected ~1s)" + ) + assert all( + r.choices[0].message.content == "Slow sync response to: Hello slow!" + for r in results + ) + + @pytest.mark.asyncio + async def test_concurrent_requests_with_sync_serve(self, mock_slow_sync_module): + """Test that multiple sync requests can be handled concurrently.""" + endpoint = make_chat_endpoint(mock_slow_sync_module) + + requests = [ + ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content=f"Request {i}")], + ) + for i in range(3) + ] + + # Run requests concurrently + responses = await asyncio.gather(*[endpoint(req) for req in requests]) + + # All should complete successfully + assert len(responses) == 3 + for i, response in enumerate(responses): + assert ( + response.choices[0].message.content + == f"Slow sync response to: Request {i}" + ) + + @pytest.mark.asyncio + async def test_requirements_passed_to_serve(self, mock_sync_module): + """Test that requirements are correctly passed to serve function.""" + endpoint = make_chat_endpoint(mock_sync_module) + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Test")], + requirements=["req1", "req2"], + ) + + await endpoint(request) + + # Verify serve was called with requirements + mock_sync_module.serve.assert_called_once() + call_kwargs = mock_sync_module.serve.call_args.kwargs + assert call_kwargs["requirements"] == ["req1", "req2"] + + @pytest.mark.asyncio + async def test_model_options_passed_to_serve(self, mock_sync_module): + """Test that model options are correctly passed to serve function.""" + endpoint = make_chat_endpoint(mock_sync_module) + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Test")], + temperature=0.7, + max_tokens=100, + ) + + await endpoint(request) + + # Verify serve was called with model_options + mock_sync_module.serve.assert_called_once() + call_kwargs = mock_sync_module.serve.call_args.kwargs + model_options = call_kwargs["model_options"] + assert ModelOption.TEMPERATURE in model_options + assert ModelOption.MAX_NEW_TOKENS in model_options + + @pytest.mark.asyncio + async def test_openai_params_mapped_to_model_options(self, mock_sync_module): + """Test that OpenAI parameters are mapped to ModelOption sentinels.""" + endpoint = make_chat_endpoint(mock_sync_module) + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Test")], + temperature=0.8, + max_tokens=150, + seed=42, + ) + + await endpoint(request) + + # Verify parameters are mapped correctly + mock_sync_module.serve.assert_called_once() + call_kwargs = mock_sync_module.serve.call_args.kwargs + model_options = call_kwargs["model_options"] + + assert model_options[ModelOption.TEMPERATURE] == 0.8 + assert model_options[ModelOption.MAX_NEW_TOKENS] == 150 + assert model_options[ModelOption.SEED] == 42 + + +class TestEndpointIntegration: + """Integration tests for the full endpoint.""" + + def test_endpoint_name_set_correctly(self, mock_sync_module): + """Test that endpoint function name is set correctly.""" + endpoint = make_chat_endpoint(mock_sync_module) + assert endpoint.__name__ == "chat_test_sync_module_endpoint" + + @pytest.mark.asyncio + async def test_completion_id_generated(self, mock_sync_module): + """Test that each response gets a unique completion ID.""" + endpoint = make_chat_endpoint(mock_sync_module) + + request = ChatCompletionRequest( + model="test-model", messages=[ChatMessage(role="user", content="Test")] + ) + + response1 = await endpoint(request) + response2 = await endpoint(request) + + assert response1.id.startswith("chatcmpl-") + assert response2.id.startswith("chatcmpl-") + assert response1.id != response2.id + + @pytest.mark.asyncio + async def test_timestamp_generated(self, mock_sync_module): + """Test that response includes a timestamp.""" + endpoint = make_chat_endpoint(mock_sync_module) + + request = ChatCompletionRequest( + model="test-model", messages=[ChatMessage(role="user", content="Test")] + ) + + response = await endpoint(request) + + assert isinstance(response.created, int) + assert response.created > 0