@@ -1109,7 +1109,7 @@ def n_estimated(self) -> int:
11091109
11101110
11111111class Hybridization (BaseModel ):
1112- """Assigns NN inputs and outputs."""
1112+ """Assigns PEtab SciML NN inputs and outputs."""
11131113
11141114 #: The target ID.
11151115 target_id : str = Field (alias = C .TARGET_ID )
@@ -1136,7 +1136,7 @@ def _sympify(cls, v):
11361136
11371137
11381138class HybridizationTable (BaseTable [Hybridization ]):
1139- """PEtab hybridization table."""
1139+ """PEtab SciML hybridization table."""
11401140
11411141 @property
11421142 def hybridizations (self ) -> list [Hybridization ]:
@@ -1153,7 +1153,7 @@ def from_df(cls, df: pd.DataFrame, **kwargs) -> HybridizationTable:
11531153 Hybridization (
11541154 ** row .to_dict (),
11551155 )
1156- for _ , row in df .reset_index (). iterrows ()
1156+ for _ , row in df .iterrows ()
11571157 ]
11581158
11591159 return cls (hybridizations , ** kwargs )
@@ -1205,6 +1205,9 @@ def __init__(
12051205 measurement_tables : list [MeasurementTable ] = None ,
12061206 parameter_tables : list [ParameterTable ] = None ,
12071207 mapping_tables : list [MappingTable ] = None ,
1208+ neural_networks : list [NNModel ] | None = None ,
1209+ hybridization_tables : list [HybridizationTable ] | None = None ,
1210+ array_data_files : list [ArrayData ] | None = None ,
12081211 config : ProblemConfig = None ,
12091212 ):
12101213 from ..v2 .lint import default_validation_tasks
@@ -1221,6 +1224,11 @@ def __init__(
12211224 self .measurement_tables = measurement_tables or [MeasurementTable ()]
12221225 self .mapping_tables = mapping_tables or [MappingTable ()]
12231226 self .parameter_tables = parameter_tables or [ParameterTable ()]
1227+ self .neural_networks = neural_networks or []
1228+ self .hybridization_tables = hybridization_tables or [
1229+ HybridizationTable ()
1230+ ]
1231+ self .array_data_files = array_data_files or []
12241232
12251233 def __repr__ (self ):
12261234 return f"<{ self .__class__ .__name__ } id={ self .id !r} >"
@@ -1393,6 +1401,59 @@ def from_yaml(
13931401 else None
13941402 )
13951403
1404+ # sciml extension
1405+ if config .extensions and config .extensions [C .SCIML ]:
1406+ try :
1407+ from petab_sciml import (
1408+ ArrayDataStandard ,
1409+ NNModel ,
1410+ NNModelStandard ,
1411+ )
1412+ except ImportError as e :
1413+ raise ImportError (
1414+ "To generate a PEtab SciML problem, (petab_sciml) must be"
1415+ "installed."
1416+ ) from e
1417+
1418+ # Neural network classes are constructed via pytorch for now to get the
1419+ # proper inputs
1420+ neural_networks = (
1421+ [
1422+ NNModel .from_pytorch_module (
1423+ NNModelStandard .load_data (
1424+ _generate_path (
1425+ file_path = nn_config .location ,
1426+ base_path = base_path ,
1427+ )
1428+ ).to_pytorch_module (),
1429+ nn_model_id = nn_id ,
1430+ )
1431+ for nn_id , nn_config in (
1432+ config .extensions [C .SCIML ].neural_nets or {}
1433+ ).items ()
1434+ ]
1435+ if config .extensions and config .extensions [C .SCIML ]
1436+ else None
1437+ )
1438+
1439+ hybridization_tables = (
1440+ [
1441+ HybridizationTable .from_tsv (f , base_path )
1442+ for f in config .extensions [C .SCIML ].hybridization_files
1443+ ]
1444+ if config .extensions and config .extensions [C .SCIML ]
1445+ else None
1446+ )
1447+
1448+ array_data_files = (
1449+ [
1450+ ArrayDataStandard .load_data (_generate_path (f , base_path ))
1451+ for f in config .extensions [C .SCIML ].array_files
1452+ ]
1453+ if config .extensions and config .extensions [C .SCIML ]
1454+ else None
1455+ )
1456+
13961457 return Problem (
13971458 config = config ,
13981459 models = models ,
@@ -1402,6 +1463,9 @@ def from_yaml(
14021463 measurement_tables = measurement_tables ,
14031464 parameter_tables = parameter_tables ,
14041465 mapping_tables = mapping_tables ,
1466+ neural_networks = neural_networks ,
1467+ hybridization_tables = hybridization_tables ,
1468+ array_data_files = array_data_files ,
14051469 )
14061470
14071471 @staticmethod
@@ -1708,6 +1772,15 @@ def id(self, value: str):
17081772 self .config = ProblemConfig (format_version = "2.0.0" )
17091773 self .config .id = value
17101774
1775+ @property
1776+ def hybridizations (self ) -> list [Hybridization ]:
1777+ """List of hybridizations in the hybridization table(s)."""
1778+ return list (
1779+ chain .from_iterable (
1780+ ht .hybridizations for ht in self .hybridization_tables
1781+ )
1782+ )
1783+
17111784 def get_optimization_parameters (self ) -> list [str ]:
17121785 """
17131786 Get the list of optimization parameter IDs from parameter table.
0 commit comments