Skip to content

Commit 1dabbc1

Browse files
committed
add sciml problem config
1 parent 7e72d12 commit 1dabbc1

2 files changed

Lines changed: 73 additions & 6 deletions

File tree

petab/v2/C.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,8 @@
258258
MAPPING_FILES = "mapping_files"
259259
#: Extensions key in the YAML file
260260
EXTENSIONS = "extensions"
261+
#: PEtab SciML extension
262+
SCIML = "sciml"
261263

262264

263265
# MAPPING

petab/v2/core.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2012,14 +2012,21 @@ def validate(
20122012

20132013
validation_results = ValidationResultList()
20142014

2015-
if self.config and self.config.extensions:
2016-
extensions = ",".join(self.config.extensions.keys())
2015+
supported_extensions = {C.SCIML}
2016+
if (
2017+
self.config
2018+
and self.config.extensions
2019+
and (self.config.extensions.keys() - supported_extensions)
2020+
):
2021+
extensions_without_support = ",".join(
2022+
self.config.extensions.keys() - supported_extensions
2023+
)
20172024
validation_results.append(
20182025
ValidationIssue(
20192026
ValidationIssueSeverity.WARNING,
2020-
"Validation of PEtab extensions is not yet implemented, "
2021-
"but the given problem uses the following extensions: "
2022-
f"{extensions}",
2027+
"The given problem uses the following extensions for "
2028+
"which validation is not yet implemented: "
2029+
f"{extensions_without_support}",
20232030
)
20242031
)
20252032

@@ -2521,13 +2528,44 @@ class ModelFile(BaseModel):
25212528
)
25222529

25232530

2531+
class NeuralNetConfig(BaseModel):
2532+
"""A neural net in the PEtab SciML problem configuration."""
2533+
2534+
location: AnyUrl | Path
2535+
pre_initialization: bool
2536+
format: str
2537+
2538+
model_config = ConfigDict(
2539+
validate_assignment=True,
2540+
)
2541+
2542+
25242543
class ExtensionConfig(BaseModel):
25252544
"""The configuration of a PEtab extension."""
25262545

25272546
version: str
25282547
config: dict
25292548

25302549

2550+
class SciMLConfig(BaseModel):
2551+
"""The extended configuration of a PEtab SciML problem."""
2552+
2553+
#: The PEtab SciML format version.
2554+
version: str = "0.1.0"
2555+
#: The paths to the array data files.
2556+
# Absolute or relative to `base_path`.
2557+
array_files: list[AnyUrl | Path] = []
2558+
#: The paths to the hybridization tables.
2559+
# Absolute or relative to `base_path`.
2560+
hybridization_files: list[AnyUrl | Path] = []
2561+
#: The neural network IDs and info.
2562+
neural_nets: dict[str, NeuralNetConfig] | None = {}
2563+
2564+
model_config = ConfigDict(
2565+
validate_assignment=True,
2566+
)
2567+
2568+
25312569
class ProblemConfig(BaseModel):
25322570
"""The PEtab problem configuration."""
25332571

@@ -2577,6 +2615,23 @@ class ProblemConfig(BaseModel):
25772615
validate_assignment=True,
25782616
)
25792617

2618+
@field_validator("extensions", mode="before")
2619+
@classmethod
2620+
def _parse_extensions(cls, v):
2621+
"""Parse extensions dict and convert known extensions to their specific
2622+
config classes."""
2623+
if isinstance(v, dict):
2624+
parsed_extensions = {}
2625+
for ext_name, ext_config in v.items():
2626+
if ext_name == C.SCIML:
2627+
# Convert sciml extension to SciMLConfig
2628+
parsed_extensions[ext_name] = SciMLConfig(**ext_config)
2629+
else:
2630+
# Keep other extensions as ExtensionConfig
2631+
parsed_extensions[ext_name] = ExtensionConfig(**ext_config)
2632+
return parsed_extensions
2633+
return v
2634+
25802635
# convert parameter_file to list
25812636
@field_validator(
25822637
"parameter_files",
@@ -2614,12 +2669,22 @@ def to_yaml(self, filename: str | Path):
26142669

26152670
for model_id in data.get("model_files", {}):
26162671
data["model_files"][model_id][C.MODEL_LOCATION] = str(
2617-
data["model_files"][model_id]["location"]
2672+
data["model_files"][model_id][C.MODEL_LOCATION]
26182673
)
26192674
if data["id"] is None:
26202675
# The schema requires a valid id or no id field at all.
26212676
del data["id"]
26222677

2678+
for ext_id, d_ext in data[C.EXTENSIONS].items():
2679+
if ext_id == C.SCIML:
2680+
# convert Paths to strings
2681+
for key in ("array_files", "hybridization_files"):
2682+
d_ext[key] = list(map(str, data[key]))
2683+
for nn in d_ext["neural_nets"]:
2684+
d_ext["neural_nets"][nn][C.MODEL_LOCATION] = str(
2685+
d_ext["neural_nets"][nn][C.MODEL_LOCATION]
2686+
)
2687+
26232688
write_yaml(data, filename)
26242689

26252690
@property

0 commit comments

Comments
 (0)