diff --git a/fasthtml/core.py b/fasthtml/core.py
index f7bc5e69..3c194976 100644
--- a/fasthtml/core.py
+++ b/fasthtml/core.py
@@ -566,6 +566,22 @@ def def_hdrs(htmx=True, surreal=True):
document.body.addEventListener('htmx:wsAfterMessage', sendmsg);
};"""))
+# %% ../nbs/api/00_core.ipynb #95601256
+def _wrap_lifespan(lifespan, on_startup, on_shutdown):
+ "Wrap a lifespan context manager with on_startup/on_shutdown callbacks."
+ on_startup,on_shutdown = listify(on_startup),listify(on_shutdown)
+ if not on_startup and not on_shutdown: return lifespan
+ @contextlib.asynccontextmanager
+ async def _lifespan(app):
+ for h in on_startup: await h() if inspect.iscoroutinefunction(h) else h()
+ try:
+ if lifespan:
+ async with lifespan(app) as state: yield state
+ else: yield
+ finally:
+ for h in on_shutdown: await h() if inspect.iscoroutinefunction(h) else h()
+ return _lifespan
+
# %% ../nbs/api/00_core.ipynb #3327a1e9
class FastHTML(Starlette):
def __init__(self, debug=False, routes=None, middleware=None, title: str = "FastHTML page", exception_handlers=None,
@@ -586,7 +602,7 @@ def __init__(self, debug=False, routes=None, middleware=None, title: str = "Fast
from IPython.display import display,HTML
if nb_hdrs: display(HTML(to_xml(tuple(hdrs))))
middleware.append(cors_allow)
- on_startup,on_shutdown = listify(on_startup) or None,listify(on_shutdown) or None
+ lifespan = _wrap_lifespan(lifespan, on_startup, on_shutdown)
self.lifespan,self.hdrs,self.ftrs = lifespan,hdrs,ftrs
self.body_wrap,self.before,self.after,self.htmlkw,self.bodykw = body_wrap,before,after,htmlkw,bodykw
self.secret_key = get_key(secret_key, key_fname)
@@ -600,7 +616,7 @@ def __init__(self, debug=False, routes=None, middleware=None, title: str = "Fast
def _not_found(req, exc): return Response('404 Not Found', status_code=404)
exception_handlers[404] = _not_found
excs = {k:_wrap_ex(v, k, hdrs, ftrs, htmlkw, bodykw, body_wrap=body_wrap) for k,v in exception_handlers.items()}
- super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
+ super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, lifespan=lifespan)
# %% ../nbs/api/00_core.ipynb #dce68049
class HostRoute(Route):
diff --git a/nbs/api/00_core.ipynb b/nbs/api/00_core.ipynb
index b75bc22c..379fa4c8 100644
--- a/nbs/api/00_core.ipynb
+++ b/nbs/api/00_core.ipynb
@@ -1715,6 +1715,30 @@
" };\"\"\"))"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "95601256",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "def _wrap_lifespan(lifespan, on_startup, on_shutdown):\n",
+ " \"Wrap a lifespan context manager with on_startup/on_shutdown callbacks.\"\n",
+ " on_startup,on_shutdown = listify(on_startup),listify(on_shutdown)\n",
+ " if not on_startup and not on_shutdown: return lifespan\n",
+ " @contextlib.asynccontextmanager\n",
+ " async def _lifespan(app):\n",
+ " for h in on_startup: await h() if inspect.iscoroutinefunction(h) else h()\n",
+ " try:\n",
+ " if lifespan:\n",
+ " async with lifespan(app) as state: yield state\n",
+ " else: yield\n",
+ " finally:\n",
+ " for h in on_shutdown: await h() if inspect.iscoroutinefunction(h) else h()\n",
+ " return _lifespan"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -1742,7 +1766,7 @@
" from IPython.display import display,HTML\n",
" if nb_hdrs: display(HTML(to_xml(tuple(hdrs))))\n",
" middleware.append(cors_allow)\n",
- " on_startup,on_shutdown = listify(on_startup) or None,listify(on_shutdown) or None\n",
+ " lifespan = _wrap_lifespan(lifespan, on_startup, on_shutdown)\n",
" self.lifespan,self.hdrs,self.ftrs = lifespan,hdrs,ftrs\n",
" self.body_wrap,self.before,self.after,self.htmlkw,self.bodykw = body_wrap,before,after,htmlkw,bodykw\n",
" self.secret_key = get_key(secret_key, key_fname)\n",
@@ -1756,7 +1780,7 @@
" def _not_found(req, exc): return Response('404 Not Found', status_code=404)\n",
" exception_handlers[404] = _not_found\n",
" excs = {k:_wrap_ex(v, k, hdrs, ftrs, htmlkw, bodykw, body_wrap=body_wrap) for k,v in exception_handlers.items()}\n",
- " super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)"
+ " super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, lifespan=lifespan)"
]
},
{
diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py
new file mode 100644
index 00000000..047ce67c
--- /dev/null
+++ b/tests/test_lifespan.py
@@ -0,0 +1,76 @@
+from fasthtml.common import *
+from starlette.testclient import TestClient
+import contextlib
+
+# Test basic app creation works (previously crashed with Starlette 1.0)
+def test_basic_app():
+ app, rt = fast_app()
+ cli = TestClient(app)
+ @rt('/')
+ def get(): return P('hello')
+ res = cli.get('/')
+ assert 'hello' in res.text
+
+def test_on_startup_shutdown():
+ started, stopped = [], []
+ app = FastHTML(on_startup=[lambda: started.append(1)], on_shutdown=[lambda: stopped.append(1)])
+ cli = TestClient(app)
+ with cli:
+ assert started == [1]
+ assert stopped == [1]
+
+def test_lifespan_only():
+ state = []
+ @contextlib.asynccontextmanager
+ async def lifespan(app):
+ state.append('started')
+ yield
+ state.append('stopped')
+ app = FastHTML(lifespan=lifespan)
+ cli = TestClient(app)
+ with cli:
+ assert state == ['started']
+ assert state == ['started', 'stopped']
+
+def test_lifespan_with_startup_shutdown():
+ order = []
+ @contextlib.asynccontextmanager
+ async def lifespan(app):
+ order.append('lifespan_start')
+ yield
+ order.append('lifespan_stop')
+ app = FastHTML(
+ lifespan=lifespan,
+ on_startup=[lambda: order.append('on_startup')],
+ on_shutdown=[lambda: order.append('on_shutdown')],
+ )
+ cli = TestClient(app)
+ with cli:
+ assert order == ['on_startup', 'lifespan_start']
+ assert order == ['on_startup', 'lifespan_start', 'lifespan_stop', 'on_shutdown']
+
+def test_async_startup_shutdown():
+ state = []
+ async def astart(): state.append('async_start')
+ async def astop(): state.append('async_stop')
+ app = FastHTML(on_startup=[astart], on_shutdown=[astop])
+ cli = TestClient(app)
+ with cli:
+ assert state == ['async_start']
+ assert state == ['async_start', 'async_stop']
+
+def test_shutdown_runs_on_lifespan_error():
+ state = []
+ @contextlib.asynccontextmanager
+ async def lifespan(app):
+ yield
+ raise RuntimeError('lifespan error')
+ app = FastHTML(
+ lifespan=lifespan,
+ on_shutdown=[lambda: state.append('shutdown')],
+ )
+ cli = TestClient(app, raise_server_exceptions=False)
+ try:
+ with cli: pass
+ except RuntimeError: pass
+ assert state == ['shutdown']