|
50 | 50 | to_thread, |
51 | 51 | ) |
52 | 52 | from anyio.abc import TaskGroup, TaskStatus |
| 53 | +from anyio.from_thread import BlockingPortal |
53 | 54 | from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
54 | 55 | from IPython.core.error import StdinNotImplementedError |
55 | 56 | from jupyter_client.session import Session |
@@ -132,7 +133,8 @@ class Kernel(SingletonConfigurable): |
132 | 133 | _send_exec_request: Dict[dict[zmq_anyio.Socket, MemoryObjectSendStream]] = Dict() |
133 | 134 | _main_subshell_ready = Instance(Event, ()) |
134 | 135 | asyncio_event_loop = Instance(asyncio.AbstractEventLoop, allow_none=True, read_only=True) # type:ignore[call-overload] |
135 | | - tg = Instance(TaskGroup, read_only=True) |
| 136 | + _tg_main = Instance(TaskGroup) |
| 137 | + _portal = Instance(BlockingPortal) |
136 | 138 |
|
137 | 139 | log: logging.Logger = Instance(logging.Logger, allow_none=True) # type:ignore[assignment] |
138 | 140 |
|
@@ -444,17 +446,22 @@ async def shell_main(self, subshell_id: str | None): |
444 | 446 | await tg.start(socket.start) |
445 | 447 | tg.start_soon(self._process_shell, socket) |
446 | 448 | tg.start_soon(self._execute_request_loop, receive_stream) |
447 | | - if subshell_id is None: |
448 | | - # Main subshell. |
| 449 | + if not subshell_id: |
| 450 | + # Main subshell |
449 | 451 | with contextlib.suppress(RuntimeError): |
450 | 452 | self.set_trait("asyncio_event_loop", asyncio.get_running_loop()) |
451 | 453 | async with create_task_group() as tg_main: |
452 | 454 | with CancelScope(shield=True) as scope: |
453 | | - self.set_trait("tg", tg_main) |
454 | | - self._main_subshell_ready.set() |
455 | | - await to_thread.run_sync(self.shell_stop.wait) |
456 | | - scope.cancel() |
457 | | - tg.cancel_scope.cancel() |
| 455 | + self._tg_main = tg_main |
| 456 | + async with BlockingPortal() as portal: |
| 457 | + # Provide a portal for general threadsafe access |
| 458 | + self._portal = portal |
| 459 | + self._main_subshell_ready.set() |
| 460 | + await to_thread.run_sync(self.shell_stop.wait) |
| 461 | + await portal.stop(True) |
| 462 | + scope.cancel() |
| 463 | + tg_main.cancel_scope.cancel() |
| 464 | + tg.cancel_scope.cancel() |
458 | 465 | self._send_exec_request.pop(socket, None) |
459 | 466 | await send_stream.aclose() |
460 | 467 | await receive_stream.aclose() |
@@ -573,6 +580,17 @@ def pre_handler_hook(self): |
573 | 580 | def post_handler_hook(self): |
574 | 581 | """Hook to execute after calling message handler""" |
575 | 582 |
|
| 583 | + def start_soon(self, func, *args): |
| 584 | + "Run a coroutine in the main thread taskgroup." |
| 585 | + try: |
| 586 | + if self._portal._event_loop_thread_id == threading.get_ident(): |
| 587 | + self._tg_main.start_soon(func, *args) |
| 588 | + else: |
| 589 | + self._portal.start_task_soon(func, *args) |
| 590 | + except Exception: |
| 591 | + self.log.exception("portal call failed") |
| 592 | + raise |
| 593 | + |
576 | 594 | async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: |
577 | 595 | """Process messages on shell and control channels""" |
578 | 596 | async with create_task_group() as tg: |
@@ -604,6 +622,7 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: |
604 | 622 | def stop(self): |
605 | 623 | self.shell_stop.set() |
606 | 624 | self.control_stop.set() |
| 625 | + self._main_subshell_ready = Event() |
607 | 626 |
|
608 | 627 | def record_ports(self, ports): |
609 | 628 | """Record the ports that this kernel is using. |
|
0 commit comments