Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Source = "https://github.com/aiidateam/aiida-pythonjob"
"pythonjob.jsonable_data" = "aiida_pythonjob.data.jsonable_data:JsonableData"
"pythonjob.ase.atoms.Atoms" = "aiida_pythonjob.data.atoms:AtomsData"
"pythonjob.builtins.NoneType" = "aiida_pythonjob.data.common_data:NoneData"
"pythonjob.builtins.function" = "aiida_pythonjob.data.common_data:FunctionData"
"pythonjob.datetime.datetime" = "aiida_pythonjob.data.common_data:DateTimeData"

[project.entry-points."aiida.calculations"]
Expand Down
10 changes: 8 additions & 2 deletions src/aiida_pythonjob/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from .common_data import DateTimeData
from .common_data import DateTimeData, FunctionData
from .pickled_data import PickledData
from .serializer import general_serializer, serialize_to_aiida_nodes

__all__ = ("DateTimeData", "PickledData", "general_serializer", "serialize_to_aiida_nodes")
__all__ = (
"DateTimeData",
"FunctionData",
"PickledData",
"general_serializer",
"serialize_to_aiida_nodes",
)
48 changes: 48 additions & 0 deletions src/aiida_pythonjob/data/common_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,51 @@ def value(self) -> datetime.datetime:

def __str__(self):
return str(self.value)


class FunctionData(Data):
"""AiiDA node to store a Python function path."""

def __init__(self, value, **kwargs):
module = getattr(value, "__module__", None)
qualname = getattr(value, "__qualname__", None) or getattr(value, "__name__", None)
if not module or not qualname:
raise TypeError(f"Expected a function-like object, got {type(value)}")
super().__init__(**kwargs)
self.base.attributes.set("module_path", module)
self.base.attributes.set("qualname", qualname)

@property
def module_path(self) -> str:
return self.base.attributes.get("module_path")

@property
def qualname(self) -> str:
return self.base.attributes.get("qualname")

@property
def path(self) -> str:
return f"{self.module_path}:{self.qualname}"

@property
def value(self):
from importlib import import_module

try:
module = import_module(self.module_path)
except Exception as exc:
raise ImportError(
f"Failed to import function module '{self.module_path}' for FunctionData '{self.path}': {exc}"
) from exc

obj = module
try:
for part in self.qualname.split("."):
obj = getattr(obj, part)
except AttributeError as exc:
raise ImportError(f"Failed to resolve function '{self.path}': attribute '{part}' not found.") from exc

return obj

def __str__(self) -> str:
return self.path
2 changes: 1 addition & 1 deletion src/aiida_pythonjob/data/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def general_serializer(
serializers = serializers or all_serializers

# 1) If it is already an AiiDA node, just return it
if isinstance(data, orm.Data):
if isinstance(data, orm.Node):
return data
elif isinstance(data, common.extendeddicts.AttributeDict):
# if the data is an AttributeDict, use it directly
Expand Down
12 changes: 12 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,18 @@ def test_datetime_data():
DateTimeData("2024-06-01")


def _sample_function():
return "ok"


def test_function_data():
from aiida_pythonjob.data.common_data import FunctionData

func_data = FunctionData(_sample_function)
assert func_data.path.endswith(":_sample_function")
assert func_data.value is _sample_function


def test_jsonable_data_pydantic_and_dataclass():
pytest.importorskip("pydantic")

Expand Down