Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ requires-python = '>=3.9'
dependencies = [
"greenback~=1.0",
"greenlet>=3.2.4", # We need to expicitly set our lowest bond, unfortunatly greenback does not pin the exact version
"kiwipy[rmq]~=0.9.0",
"kiwipy @ git+https://github.com/aiidateam/kiwipy@nowait",
"pyyaml~=6.0",
]
[project.urls]
Expand Down
4 changes: 2 additions & 2 deletions src/plumpy/communications.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def add_broadcast_subscriber(
def remove_broadcast_subscriber(self, identifier: 'ID_TYPE') -> None:
return self._communicator.remove_broadcast_subscriber(identifier)

def task_send(self, task: Any, no_reply: bool = False) -> kiwipy.Future:
return self._communicator.task_send(task, no_reply)
def task_send(self, task: Any, no_reply: bool = False, nowait: bool = False) -> kiwipy.Future:
return self._communicator.task_send(task, no_reply, nowait)

def rpc_send(self, recipient_id: 'ID_TYPE', msg: Any) -> kiwipy.Future:
return self._communicator.rpc_send(recipient_id, msg)
Expand Down
63 changes: 34 additions & 29 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast

import kiwipy
from kiwipy.rmq import TaskResult

from . import communications, futures, loaders, persistence
from .utils import PID_TYPE
Expand All @@ -17,6 +18,7 @@
'ProcessLauncher',
'RemoteProcessController',
'RemoteProcessThreadController',
'TaskResult',
'create_continue_body',
'create_launch_body',
]
Expand Down Expand Up @@ -260,7 +262,7 @@ async def continue_process(
"""
message = create_continue_body(pid=pid, tag=tag, nowait=nowait)
# Wait for the communication to go through
continue_future = self._communicator.task_send(message, no_reply=no_reply)
continue_future = self._communicator.task_send(message, no_reply=no_reply, nowait=nowait)
future = await asyncio.wrap_future(continue_future)

if no_reply:
Expand Down Expand Up @@ -294,7 +296,7 @@ async def launch_process(
"""

message = create_launch_body(process_class, init_args, init_kwargs, persist, loader, nowait)
launch_future = self._communicator.task_send(message, no_reply=no_reply)
launch_future = self._communicator.task_send(message, no_reply=no_reply, nowait=nowait)
future = await asyncio.wrap_future(launch_future)

if no_reply:
Expand Down Expand Up @@ -333,7 +335,7 @@ async def execute_process(
pid: 'PID_TYPE' = await asyncio.wrap_future(future)

message = create_continue_body(pid, nowait=nowait)
continue_future = self._communicator.task_send(message, no_reply=no_reply)
continue_future = self._communicator.task_send(message, no_reply=no_reply, nowait=nowait)
future = await asyncio.wrap_future(continue_future)

if no_reply:
Expand Down Expand Up @@ -428,7 +430,7 @@ def continue_process(
self, pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False, no_reply: bool = False
) -> Union[None, PID_TYPE, ProcessResult]:
message = create_continue_body(pid=pid, tag=tag, nowait=nowait)
return self.task_send(message, no_reply=no_reply)
return self.task_send(message, no_reply=no_reply, nowait=nowait)

def launch_process(
self,
Expand All @@ -453,7 +455,7 @@ def launch_process(
:return: the pid of the created process or the outputs (if nowait=False)
"""
message = create_launch_body(process_class, init_args, init_kwargs, persist, loader, nowait)
return self.task_send(message, no_reply=no_reply)
return self.task_send(message, no_reply=no_reply, nowait=nowait)

def execute_process(
self,
Expand Down Expand Up @@ -492,15 +494,16 @@ def on_created(_: Any) -> None:
create_future.add_done_callback(on_created)
return execute_future

def task_send(self, message: Any, no_reply: bool = False) -> Optional[Any]:
def task_send(self, message: Any, no_reply: bool = False, nowait: bool = False) -> Optional[Any]:
"""
Send a task to be performed using the communicator

:param message: the task message
:param no_reply: if True, this call will be fire-and-forget, i.e. no return value
:param nowait: if True, return immediately with task_id instead of waiting for result
:return: the response from the remote side (if no_reply=False)
"""
return self._communicator.task_send(message, no_reply=no_reply)
return self._communicator.task_send(message, no_reply=no_reply, nowait=nowait)


class ProcessLauncher:
Expand Down Expand Up @@ -545,10 +548,15 @@ def __init__(
else:
self._loader = loaders.get_object_loader()

async def __call__(self, communicator: kiwipy.Communicator, task: Dict[str, Any]) -> Union[PID_TYPE, ProcessResult]:
async def __call__(
self, communicator: kiwipy.Communicator, task: Dict[str, Any]
) -> Union[TaskResult, 'PID_TYPE']:
"""
Receive a task.
:param task: The task message

:param communicator: the communicator
:param task: the task message
:return: TaskResult for launch/continue tasks, PID for create tasks
"""
task_type = task[TASK_KEY]
if task_type == LAUNCH_TASK:
Expand All @@ -568,7 +576,7 @@ async def _launch(
nowait: bool,
init_args: Optional[Sequence[Any]] = None,
init_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[PID_TYPE, ProcessResult]:
) -> TaskResult:
"""
Launch the process

Expand All @@ -578,7 +586,7 @@ async def _launch(
:param nowait: if True only return when the process finishes
:param init_args: positional arguments to the process constructor
:param init_kwargs: keyword arguments to the process constructor
:return: the pid of the created process or the outputs (if nowait=False)
:return: TaskResult with task_id (PID) and result Future
"""
if persist and not self._persister:
raise communications.TaskRejected('Cannot persist process, no persister')
Expand All @@ -593,25 +601,26 @@ async def _launch(
if persist and self._persister is not None:
self._persister.save_checkpoint(proc)

if nowait:
# XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails
asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006
return proc.pid

await proc.step_until_terminated()

return proc.future().result()
# Start process in background, return TaskResult with Future
# kiwipy will handle early reply (if nowait) and wait for Future before acking
asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006
return TaskResult(task_id=proc.pid, result=communications.plum_to_kiwi_future(proc.future()))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is proc.future().result() actually again, is it always pid?


async def _continue(
self, _communicator: kiwipy.Communicator, pid: 'PID_TYPE', nowait: bool, tag: Optional[str] = None
) -> Union[PID_TYPE, ProcessResult]:
self,
_communicator: kiwipy.Communicator,
pid: 'PID_TYPE',
nowait: bool,
tag: Optional[str] = None,
) -> TaskResult:
"""
Continue the process

:param _communicator: the communicator
:param pid: the pid of the process to continue
:param nowait: if True don't wait for the process to complete
:param tag: the checkpoint tag to continue from
:return: TaskResult with task_id (PID) and result Future
"""
if not self._persister:
LOGGER.warning('rejecting task: cannot continue process<%d> because no persister is available', pid)
Expand All @@ -621,14 +630,10 @@ async def _continue(
saved_state = self._persister.load_checkpoint(pid, tag)
proc = cast('Process', saved_state.unbundle(self._load_context))

if nowait:
# XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails
asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006
return proc.pid

await proc.step_until_terminated()

return proc.future().result()
# Start process in background, return TaskResult with Future
# kiwipy will handle early reply (if nowait) and wait for Future before acking
asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006
return TaskResult(task_id=proc.pid, result=communications.plum_to_kiwi_future(proc.future()))

async def _create(
self,
Expand Down
10 changes: 8 additions & 2 deletions tests/test_process_comms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
import asyncio

import pytest

import plumpy
Expand Down Expand Up @@ -37,7 +39,9 @@ async def test_continue():
del process
process = None

result = await launcher._continue(None, **plumpy.create_continue_body(pid)[process_comms.TASK_ARGS])
task_result = await launcher._continue(None, **plumpy.create_continue_body(pid)[process_comms.TASK_ARGS])
# _continue returns a TaskResult; wait for the result Future to resolve
result = await asyncio.wrap_future(task_result.result)
assert result == utils.DummyProcess.EXPECTED_OUTPUTS


Expand All @@ -51,5 +55,7 @@ async def test_loader_is_used():
launcher = plumpy.ProcessLauncher(persister=persister, loader=loader)

continue_task = plumpy.create_continue_body(proc.pid)
result = await launcher._continue(None, **continue_task[process_comms.TASK_ARGS])
task_result = await launcher._continue(None, **continue_task[process_comms.TASK_ARGS])
# _continue returns a TaskResult; wait for the result Future to resolve
result = await asyncio.wrap_future(task_result.result)
assert result == utils.DummyProcess.EXPECTED_OUTPUTS
Loading