Skip to content

Commit 1888d2f

Browse files
committed
create Problem for sciml extension
1 parent 1dabbc1 commit 1888d2f

1 file changed

Lines changed: 76 additions & 3 deletions

File tree

petab/v2/core.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,7 +1109,7 @@ def n_estimated(self) -> int:
11091109

11101110

11111111
class Hybridization(BaseModel):
1112-
"""Assigns NN inputs and outputs."""
1112+
"""Assigns PEtab SciML NN inputs and outputs."""
11131113

11141114
#: The target ID.
11151115
target_id: str = Field(alias=C.TARGET_ID)
@@ -1136,7 +1136,7 @@ def _sympify(cls, v):
11361136

11371137

11381138
class HybridizationTable(BaseTable[Hybridization]):
1139-
"""PEtab hybridization table."""
1139+
"""PEtab SciML hybridization table."""
11401140

11411141
@property
11421142
def hybridizations(self) -> list[Hybridization]:
@@ -1153,7 +1153,7 @@ def from_df(cls, df: pd.DataFrame, **kwargs) -> HybridizationTable:
11531153
Hybridization(
11541154
**row.to_dict(),
11551155
)
1156-
for _, row in df.reset_index().iterrows()
1156+
for _, row in df.iterrows()
11571157
]
11581158

11591159
return cls(hybridizations, **kwargs)
@@ -1205,6 +1205,9 @@ def __init__(
12051205
measurement_tables: list[MeasurementTable] = None,
12061206
parameter_tables: list[ParameterTable] = None,
12071207
mapping_tables: list[MappingTable] = None,
1208+
neural_networks: list[NNModel] | None = None,
1209+
hybridization_tables: list[HybridizationTable] | None = None,
1210+
array_data_files: list[ArrayData] | None = None,
12081211
config: ProblemConfig = None,
12091212
):
12101213
from ..v2.lint import default_validation_tasks
@@ -1221,6 +1224,11 @@ def __init__(
12211224
self.measurement_tables = measurement_tables or [MeasurementTable()]
12221225
self.mapping_tables = mapping_tables or [MappingTable()]
12231226
self.parameter_tables = parameter_tables or [ParameterTable()]
1227+
self.neural_networks = neural_networks or []
1228+
self.hybridization_tables = hybridization_tables or [
1229+
HybridizationTable()
1230+
]
1231+
self.array_data_files = array_data_files or []
12241232

12251233
def __repr__(self):
12261234
return f"<{self.__class__.__name__} id={self.id!r}>"
@@ -1393,6 +1401,59 @@ def from_yaml(
13931401
else None
13941402
)
13951403

1404+
# sciml extension
1405+
if config.extensions and config.extensions[C.SCIML]:
1406+
try:
1407+
from petab_sciml import (
1408+
ArrayDataStandard,
1409+
NNModel,
1410+
NNModelStandard,
1411+
)
1412+
except ImportError as e:
1413+
raise ImportError(
1414+
"To generate a PEtab SciML problem, (petab_sciml) must be"
1415+
"installed."
1416+
) from e
1417+
1418+
# Neural network classes are constructed via pytorch for now to get the
1419+
# proper inputs
1420+
neural_networks = (
1421+
[
1422+
NNModel.from_pytorch_module(
1423+
NNModelStandard.load_data(
1424+
_generate_path(
1425+
file_path=nn_config.location,
1426+
base_path=base_path,
1427+
)
1428+
).to_pytorch_module(),
1429+
nn_model_id=nn_id,
1430+
)
1431+
for nn_id, nn_config in (
1432+
config.extensions[C.SCIML].neural_nets or {}
1433+
).items()
1434+
]
1435+
if config.extensions and config.extensions[C.SCIML]
1436+
else None
1437+
)
1438+
1439+
hybridization_tables = (
1440+
[
1441+
HybridizationTable.from_tsv(f, base_path)
1442+
for f in config.extensions[C.SCIML].hybridization_files
1443+
]
1444+
if config.extensions and config.extensions[C.SCIML]
1445+
else None
1446+
)
1447+
1448+
array_data_files = (
1449+
[
1450+
ArrayDataStandard.load_data(_generate_path(f, base_path))
1451+
for f in config.extensions[C.SCIML].array_files
1452+
]
1453+
if config.extensions and config.extensions[C.SCIML]
1454+
else None
1455+
)
1456+
13961457
return Problem(
13971458
config=config,
13981459
models=models,
@@ -1402,6 +1463,9 @@ def from_yaml(
14021463
measurement_tables=measurement_tables,
14031464
parameter_tables=parameter_tables,
14041465
mapping_tables=mapping_tables,
1466+
neural_networks=neural_networks,
1467+
hybridization_tables=hybridization_tables,
1468+
array_data_files=array_data_files,
14051469
)
14061470

14071471
@staticmethod
@@ -1708,6 +1772,15 @@ def id(self, value: str):
17081772
self.config = ProblemConfig(format_version="2.0.0")
17091773
self.config.id = value
17101774

1775+
@property
1776+
def hybridizations(self) -> list[Hybridization]:
1777+
"""List of hybridizations in the hybridization table(s)."""
1778+
return list(
1779+
chain.from_iterable(
1780+
ht.hybridizations for ht in self.hybridization_tables
1781+
)
1782+
)
1783+
17111784
def get_optimization_parameters(self) -> list[str]:
17121785
"""
17131786
Get the list of optimization parameter IDs from parameter table.

0 commit comments

Comments
 (0)