diff --git a/src/labthings_fastapi/server/fallback.py b/src/labthings_fastapi/server/fallback.py index 1d4a6c1f..0350e865 100644 --- a/src/labthings_fastapi/server/fallback.py +++ b/src/labthings_fastapi/server/fallback.py @@ -16,6 +16,7 @@ from typing import Any, TYPE_CHECKING from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse from jinja2 import Environment, BaseLoader, select_autoescape from starlette.responses import RedirectResponse @@ -149,6 +150,15 @@ def fallback_page(self) -> HTMLResponse: app = FallbackApp() +# Add middleware so contacting the the fallback server doesn't throw CORS errors. +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + @app.get("/") async def root() -> HTMLResponse: @@ -159,6 +169,17 @@ async def root() -> HTMLResponse: return app.fallback_page() +@app.get("/labthings_fallback") +async def fallback_route() -> bool: + """Return True, this is a LabThings Fallback Server. + + Use this to check over the API if this is a LabThings Fallback Server. + + :return: returns True. This is a LabThings Fallback Server. + """ + return True + + def _format_error_and_traceback(context: FallbackContext) -> tuple[str, str]: """Format the error and traceback. diff --git a/tests/test_fallback.py b/tests/test_fallback.py index 572473b4..e74bad60 100644 --- a/tests/test_fallback.py +++ b/tests/test_fallback.py @@ -158,6 +158,14 @@ def test_fallback_with_log(): assert "Fake log content" in html +def test_fallback_identification(): + """Test the server identifies as a fallback server.""" + app.set_context(FallbackContext()) + with TestClient(app) as client: + response = client.get("/labthings_fallback") + assert response.json() is True + + def test_actual_server_fallback(): """Test that the the server configures its startup failure correctly.