diff --git a/fasthtml/core.py b/fasthtml/core.py index f7bc5e69..3c194976 100644 --- a/fasthtml/core.py +++ b/fasthtml/core.py @@ -566,6 +566,22 @@ def def_hdrs(htmx=True, surreal=True): document.body.addEventListener('htmx:wsAfterMessage', sendmsg); };""")) +# %% ../nbs/api/00_core.ipynb #95601256 +def _wrap_lifespan(lifespan, on_startup, on_shutdown): + "Wrap a lifespan context manager with on_startup/on_shutdown callbacks." + on_startup,on_shutdown = listify(on_startup),listify(on_shutdown) + if not on_startup and not on_shutdown: return lifespan + @contextlib.asynccontextmanager + async def _lifespan(app): + for h in on_startup: await h() if inspect.iscoroutinefunction(h) else h() + try: + if lifespan: + async with lifespan(app) as state: yield state + else: yield + finally: + for h in on_shutdown: await h() if inspect.iscoroutinefunction(h) else h() + return _lifespan + # %% ../nbs/api/00_core.ipynb #3327a1e9 class FastHTML(Starlette): def __init__(self, debug=False, routes=None, middleware=None, title: str = "FastHTML page", exception_handlers=None, @@ -586,7 +602,7 @@ def __init__(self, debug=False, routes=None, middleware=None, title: str = "Fast from IPython.display import display,HTML if nb_hdrs: display(HTML(to_xml(tuple(hdrs)))) middleware.append(cors_allow) - on_startup,on_shutdown = listify(on_startup) or None,listify(on_shutdown) or None + lifespan = _wrap_lifespan(lifespan, on_startup, on_shutdown) self.lifespan,self.hdrs,self.ftrs = lifespan,hdrs,ftrs self.body_wrap,self.before,self.after,self.htmlkw,self.bodykw = body_wrap,before,after,htmlkw,bodykw self.secret_key = get_key(secret_key, key_fname) @@ -600,7 +616,7 @@ def __init__(self, debug=False, routes=None, middleware=None, title: str = "Fast def _not_found(req, exc): return Response('404 Not Found', status_code=404) exception_handlers[404] = _not_found excs = {k:_wrap_ex(v, k, hdrs, ftrs, htmlkw, bodykw, body_wrap=body_wrap) for k,v in exception_handlers.items()} - super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan) + super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, lifespan=lifespan) # %% ../nbs/api/00_core.ipynb #dce68049 class HostRoute(Route): diff --git a/nbs/api/00_core.ipynb b/nbs/api/00_core.ipynb index b75bc22c..379fa4c8 100644 --- a/nbs/api/00_core.ipynb +++ b/nbs/api/00_core.ipynb @@ -1715,6 +1715,30 @@ " };\"\"\"))" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "95601256", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def _wrap_lifespan(lifespan, on_startup, on_shutdown):\n", + " \"Wrap a lifespan context manager with on_startup/on_shutdown callbacks.\"\n", + " on_startup,on_shutdown = listify(on_startup),listify(on_shutdown)\n", + " if not on_startup and not on_shutdown: return lifespan\n", + " @contextlib.asynccontextmanager\n", + " async def _lifespan(app):\n", + " for h in on_startup: await h() if inspect.iscoroutinefunction(h) else h()\n", + " try:\n", + " if lifespan:\n", + " async with lifespan(app) as state: yield state\n", + " else: yield\n", + " finally:\n", + " for h in on_shutdown: await h() if inspect.iscoroutinefunction(h) else h()\n", + " return _lifespan" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1742,7 +1766,7 @@ " from IPython.display import display,HTML\n", " if nb_hdrs: display(HTML(to_xml(tuple(hdrs))))\n", " middleware.append(cors_allow)\n", - " on_startup,on_shutdown = listify(on_startup) or None,listify(on_shutdown) or None\n", + " lifespan = _wrap_lifespan(lifespan, on_startup, on_shutdown)\n", " self.lifespan,self.hdrs,self.ftrs = lifespan,hdrs,ftrs\n", " self.body_wrap,self.before,self.after,self.htmlkw,self.bodykw = body_wrap,before,after,htmlkw,bodykw\n", " self.secret_key = get_key(secret_key, key_fname)\n", @@ -1756,7 +1780,7 @@ " def _not_found(req, exc): return Response('404 Not Found', status_code=404)\n", " exception_handlers[404] = _not_found\n", " excs = {k:_wrap_ex(v, k, hdrs, ftrs, htmlkw, bodykw, body_wrap=body_wrap) for k,v in exception_handlers.items()}\n", - " super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)" + " super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, lifespan=lifespan)" ] }, { diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py new file mode 100644 index 00000000..047ce67c --- /dev/null +++ b/tests/test_lifespan.py @@ -0,0 +1,76 @@ +from fasthtml.common import * +from starlette.testclient import TestClient +import contextlib + +# Test basic app creation works (previously crashed with Starlette 1.0) +def test_basic_app(): + app, rt = fast_app() + cli = TestClient(app) + @rt('/') + def get(): return P('hello') + res = cli.get('/') + assert 'hello' in res.text + +def test_on_startup_shutdown(): + started, stopped = [], [] + app = FastHTML(on_startup=[lambda: started.append(1)], on_shutdown=[lambda: stopped.append(1)]) + cli = TestClient(app) + with cli: + assert started == [1] + assert stopped == [1] + +def test_lifespan_only(): + state = [] + @contextlib.asynccontextmanager + async def lifespan(app): + state.append('started') + yield + state.append('stopped') + app = FastHTML(lifespan=lifespan) + cli = TestClient(app) + with cli: + assert state == ['started'] + assert state == ['started', 'stopped'] + +def test_lifespan_with_startup_shutdown(): + order = [] + @contextlib.asynccontextmanager + async def lifespan(app): + order.append('lifespan_start') + yield + order.append('lifespan_stop') + app = FastHTML( + lifespan=lifespan, + on_startup=[lambda: order.append('on_startup')], + on_shutdown=[lambda: order.append('on_shutdown')], + ) + cli = TestClient(app) + with cli: + assert order == ['on_startup', 'lifespan_start'] + assert order == ['on_startup', 'lifespan_start', 'lifespan_stop', 'on_shutdown'] + +def test_async_startup_shutdown(): + state = [] + async def astart(): state.append('async_start') + async def astop(): state.append('async_stop') + app = FastHTML(on_startup=[astart], on_shutdown=[astop]) + cli = TestClient(app) + with cli: + assert state == ['async_start'] + assert state == ['async_start', 'async_stop'] + +def test_shutdown_runs_on_lifespan_error(): + state = [] + @contextlib.asynccontextmanager + async def lifespan(app): + yield + raise RuntimeError('lifespan error') + app = FastHTML( + lifespan=lifespan, + on_shutdown=[lambda: state.append('shutdown')], + ) + cli = TestClient(app, raise_server_exceptions=False) + try: + with cli: pass + except RuntimeError: pass + assert state == ['shutdown']