Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions src/aiida_pythonjob/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,24 @@
import typing as t
from typing import List

from aiida.engine.processes.functions import FunctionType, get_stack_size
import aiida
from aiida.engine.processes.functions import FunctionType
from aiida.manage import get_manager
from aiida.orm import ProcessNode
from node_graph.socket_spec import SocketSpec
from packaging.version import parse as parse_version

from aiida_pythonjob.calculations.pyfunction import PyFunction
from aiida_pythonjob.launch import create_inputs, prepare_pyfunction_inputs

LOGGER = logging.getLogger(__name__)

_AIIDA_VERSION = parse_version(aiida.__version__)
_NEEDS_RECURSION_LIMIT_WORKAROUND = _AIIDA_VERSION < parse_version("2.8.0")

if _NEEDS_RECURSION_LIMIT_WORKAROUND:
from aiida.engine.processes.functions import get_stack_size


# The following code is modified from the aiida-core.engine.processes.functions module
def pyfunction(
Expand All @@ -42,25 +50,26 @@ def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, "ProcessNode
:param kwargs: input keyword arguments to construct the FunctionProcess
:return: tuple of the outputs of the process and the process node
"""
frame_delta = 1000
frame_count = get_stack_size()
stack_limit = sys.getrecursionlimit()
LOGGER.info("Executing process function, current stack status: %d frames of %d", frame_count, stack_limit)

# If the current frame count is more than 80% of the stack limit, or comes within 200 frames, increase the
# stack limit by ``frame_delta``.
if frame_count > min(0.8 * stack_limit, stack_limit - 200):
LOGGER.warning(
"Current stack contains %d frames which is close to the limit of %d. Increasing the limit by %d",
frame_count,
stack_limit,
frame_delta,
if _NEEDS_RECURSION_LIMIT_WORKAROUND:
frame_delta = 1000
frame_count = get_stack_size()
stack_limit = sys.getrecursionlimit()
LOGGER.info(
"Executing process function, current stack status: %d frames of %d", frame_count, stack_limit
)
sys.setrecursionlimit(stack_limit + frame_delta)

if frame_count > min(0.8 * stack_limit, stack_limit - 200):
LOGGER.warning(
"Current stack contains %d frames which is close to the limit of %d. Increasing the limit by %d",
frame_count,
stack_limit,
frame_delta,
)
sys.setrecursionlimit(stack_limit + frame_delta)

manager = get_manager()
runner = manager.get_runner()
# # Remove all the known inputs from the kwargs
# Remove all the known inputs from the kwargs
outputs_spec = kwargs.pop("outputs_spec", None) or outputs
inputs_spec = kwargs.pop("inputs_spec", None) or inputs
metadata = kwargs.pop("metadata", None)
Expand Down