Skip to content

Commit 8fc455c

Browse files
committed
Add BlockingPortal and enhance task management in Kernel class
1 parent 62d46b4 commit 8fc455c

File tree

2 files changed

+56
-8
lines changed

2 files changed

+56
-8
lines changed

ipykernel/kernelbase.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
to_thread,
5151
)
5252
from anyio.abc import TaskGroup, TaskStatus
53+
from anyio.from_thread import BlockingPortal
5354
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
5455
from IPython.core.error import StdinNotImplementedError
5556
from jupyter_client.session import Session
@@ -132,7 +133,8 @@ class Kernel(SingletonConfigurable):
132133
_send_exec_request: Dict[dict[zmq_anyio.Socket, MemoryObjectSendStream]] = Dict()
133134
_main_subshell_ready = Instance(Event, ())
134135
asyncio_event_loop = Instance(asyncio.AbstractEventLoop, allow_none=True, read_only=True) # type:ignore[call-overload]
135-
tg = Instance(TaskGroup, read_only=True)
136+
_tg_main = Instance(TaskGroup)
137+
_portal = Instance(BlockingPortal)
136138

137139
log: logging.Logger = Instance(logging.Logger, allow_none=True) # type:ignore[assignment]
138140

@@ -444,17 +446,22 @@ async def shell_main(self, subshell_id: str | None):
444446
await tg.start(socket.start)
445447
tg.start_soon(self._process_shell, socket)
446448
tg.start_soon(self._execute_request_loop, receive_stream)
447-
if subshell_id is None:
448-
# Main subshell.
449+
if not subshell_id:
450+
# Main subshell
449451
with contextlib.suppress(RuntimeError):
450452
self.set_trait("asyncio_event_loop", asyncio.get_running_loop())
451453
async with create_task_group() as tg_main:
452454
with CancelScope(shield=True) as scope:
453-
self.set_trait("tg", tg_main)
454-
self._main_subshell_ready.set()
455-
await to_thread.run_sync(self.shell_stop.wait)
456-
scope.cancel()
457-
tg.cancel_scope.cancel()
455+
self._tg_main = tg_main
456+
async with BlockingPortal() as portal:
457+
# Provide a portal for general threadsafe access
458+
self._portal = portal
459+
self._main_subshell_ready.set()
460+
await to_thread.run_sync(self.shell_stop.wait)
461+
await portal.stop(True)
462+
scope.cancel()
463+
tg_main.cancel_scope.cancel()
464+
tg.cancel_scope.cancel()
458465
self._send_exec_request.pop(socket, None)
459466
await send_stream.aclose()
460467
await receive_stream.aclose()
@@ -573,6 +580,17 @@ def pre_handler_hook(self):
573580
def post_handler_hook(self):
574581
"""Hook to execute after calling message handler"""
575582

583+
def start_soon(self, func, *args):
584+
"Run a coroutine in the main thread taskgroup."
585+
try:
586+
if self._portal._event_loop_thread_id == threading.get_ident():
587+
self._tg_main.start_soon(func, *args)
588+
else:
589+
self._portal.start_task_soon(func, *args)
590+
except Exception:
591+
self.log.exception("portal call failed")
592+
raise
593+
576594
async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:
577595
"""Process messages on shell and control channels"""
578596
async with create_task_group() as tg:
@@ -604,6 +622,7 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:
604622
def stop(self):
605623
self.shell_stop.set()
606624
self.control_stop.set()
625+
self._main_subshell_ready = Event()
607626

608627
def record_ports(self, ports):
609628
"""Record the ports that this kernel is using.

tests/test_ipkernel_direct.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,32 @@ async def test_do_debug_request(ipkernel: IPythonKernel) -> None:
177177
msg = ipkernel.session.msg("debug_request", {})
178178
ipkernel.session.serialize(msg)
179179
await ipkernel.do_debug_request(msg)
180+
181+
182+
@pytest.mark.parametrize("mode", ["main", "external"])
183+
@pytest.mark.parametrize("exception", [True, False])
184+
async def test_start_soon(mode, exception: bool, ipkernel: IPythonKernel, anyio_backend: str):
185+
# Test we can start coroutines from various scopes
186+
import anyio
187+
from anyio import to_thread
188+
189+
async def my_test(event: anyio.Event):
190+
event.set()
191+
if exception:
192+
raise ValueError
193+
194+
events = []
195+
196+
async def start():
197+
event = anyio.Event()
198+
if mode == "main":
199+
ipkernel.start_soon(my_test, event)
200+
else:
201+
await to_thread.run_sync(ipkernel.start_soon, my_test, event)
202+
events.append(event)
203+
204+
for _ in range(50):
205+
await start()
206+
207+
for event in events:
208+
await event.wait()

0 commit comments

Comments
 (0)