Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions fasthtml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,22 @@ def def_hdrs(htmx=True, surreal=True):
document.body.addEventListener('htmx:wsAfterMessage', sendmsg);
};"""))

# %% ../nbs/api/00_core.ipynb #95601256
def _wrap_lifespan(lifespan, on_startup, on_shutdown):
"Wrap a lifespan context manager with on_startup/on_shutdown callbacks."
on_startup,on_shutdown = listify(on_startup),listify(on_shutdown)
if not on_startup and not on_shutdown: return lifespan
@contextlib.asynccontextmanager
async def _lifespan(app):
for h in on_startup: await h() if inspect.iscoroutinefunction(h) else h()
try:
if lifespan:
async with lifespan(app) as state: yield state
else: yield
finally:
for h in on_shutdown: await h() if inspect.iscoroutinefunction(h) else h()
return _lifespan

# %% ../nbs/api/00_core.ipynb #3327a1e9
class FastHTML(Starlette):
def __init__(self, debug=False, routes=None, middleware=None, title: str = "FastHTML page", exception_handlers=None,
Expand All @@ -586,7 +602,7 @@ def __init__(self, debug=False, routes=None, middleware=None, title: str = "Fast
from IPython.display import display,HTML
if nb_hdrs: display(HTML(to_xml(tuple(hdrs))))
middleware.append(cors_allow)
on_startup,on_shutdown = listify(on_startup) or None,listify(on_shutdown) or None
lifespan = _wrap_lifespan(lifespan, on_startup, on_shutdown)
self.lifespan,self.hdrs,self.ftrs = lifespan,hdrs,ftrs
self.body_wrap,self.before,self.after,self.htmlkw,self.bodykw = body_wrap,before,after,htmlkw,bodykw
self.secret_key = get_key(secret_key, key_fname)
Expand All @@ -600,7 +616,7 @@ def __init__(self, debug=False, routes=None, middleware=None, title: str = "Fast
def _not_found(req, exc): return Response('404 Not Found', status_code=404)
exception_handlers[404] = _not_found
excs = {k:_wrap_ex(v, k, hdrs, ftrs, htmlkw, bodykw, body_wrap=body_wrap) for k,v in exception_handlers.items()}
super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, lifespan=lifespan)

# %% ../nbs/api/00_core.ipynb #dce68049
class HostRoute(Route):
Expand Down
28 changes: 26 additions & 2 deletions nbs/api/00_core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1715,6 +1715,30 @@
" };\"\"\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95601256",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def _wrap_lifespan(lifespan, on_startup, on_shutdown):\n",
" \"Wrap a lifespan context manager with on_startup/on_shutdown callbacks.\"\n",
" on_startup,on_shutdown = listify(on_startup),listify(on_shutdown)\n",
" if not on_startup and not on_shutdown: return lifespan\n",
" @contextlib.asynccontextmanager\n",
" async def _lifespan(app):\n",
" for h in on_startup: await h() if inspect.iscoroutinefunction(h) else h()\n",
" try:\n",
" if lifespan:\n",
" async with lifespan(app) as state: yield state\n",
" else: yield\n",
" finally:\n",
" for h in on_shutdown: await h() if inspect.iscoroutinefunction(h) else h()\n",
" return _lifespan"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -1742,7 +1766,7 @@
" from IPython.display import display,HTML\n",
" if nb_hdrs: display(HTML(to_xml(tuple(hdrs))))\n",
" middleware.append(cors_allow)\n",
" on_startup,on_shutdown = listify(on_startup) or None,listify(on_shutdown) or None\n",
" lifespan = _wrap_lifespan(lifespan, on_startup, on_shutdown)\n",
" self.lifespan,self.hdrs,self.ftrs = lifespan,hdrs,ftrs\n",
" self.body_wrap,self.before,self.after,self.htmlkw,self.bodykw = body_wrap,before,after,htmlkw,bodykw\n",
" self.secret_key = get_key(secret_key, key_fname)\n",
Expand All @@ -1756,7 +1780,7 @@
" def _not_found(req, exc): return Response('404 Not Found', status_code=404)\n",
" exception_handlers[404] = _not_found\n",
" excs = {k:_wrap_ex(v, k, hdrs, ftrs, htmlkw, bodykw, body_wrap=body_wrap) for k,v in exception_handlers.items()}\n",
" super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)"
" super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, lifespan=lifespan)"
]
},
{
Expand Down
76 changes: 76 additions & 0 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from fasthtml.common import *
from starlette.testclient import TestClient
import contextlib

# Test basic app creation works (previously crashed with Starlette 1.0)
def test_basic_app():
app, rt = fast_app()
cli = TestClient(app)
@rt('/')
def get(): return P('hello')
res = cli.get('/')
assert 'hello' in res.text

def test_on_startup_shutdown():
started, stopped = [], []
app = FastHTML(on_startup=[lambda: started.append(1)], on_shutdown=[lambda: stopped.append(1)])
cli = TestClient(app)
with cli:
assert started == [1]
assert stopped == [1]

def test_lifespan_only():
state = []
@contextlib.asynccontextmanager
async def lifespan(app):
state.append('started')
yield
state.append('stopped')
app = FastHTML(lifespan=lifespan)
cli = TestClient(app)
with cli:
assert state == ['started']
assert state == ['started', 'stopped']

def test_lifespan_with_startup_shutdown():
order = []
@contextlib.asynccontextmanager
async def lifespan(app):
order.append('lifespan_start')
yield
order.append('lifespan_stop')
app = FastHTML(
lifespan=lifespan,
on_startup=[lambda: order.append('on_startup')],
on_shutdown=[lambda: order.append('on_shutdown')],
)
cli = TestClient(app)
with cli:
assert order == ['on_startup', 'lifespan_start']
assert order == ['on_startup', 'lifespan_start', 'lifespan_stop', 'on_shutdown']

def test_async_startup_shutdown():
state = []
async def astart(): state.append('async_start')
async def astop(): state.append('async_stop')
app = FastHTML(on_startup=[astart], on_shutdown=[astop])
cli = TestClient(app)
with cli:
assert state == ['async_start']
assert state == ['async_start', 'async_stop']

def test_shutdown_runs_on_lifespan_error():
state = []
@contextlib.asynccontextmanager
async def lifespan(app):
yield
raise RuntimeError('lifespan error')
app = FastHTML(
lifespan=lifespan,
on_shutdown=[lambda: state.append('shutdown')],
)
cli = TestClient(app, raise_server_exceptions=False)
try:
with cli: pass
except RuntimeError: pass
assert state == ['shutdown']