diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 687e36f39..4368953b6 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -174,10 +174,10 @@ Optimizers and Schedulers .. toctree:: :titlesonly: - Optimizer - Scheduler - TorchOptimizer - TorchScheduler + Optimizer Interface + Scheduler Interface + Torch Optimizer + Torch Scheduler Adaptive Functions @@ -297,11 +297,13 @@ Callbacks Switch Optimizer Switch Scheduler - Normalizer Data - PINA Progress Bar - Metric Tracker Refinement Interface + Base Refinement R3 Refinement + Data Normalizer + Metric Tracker + PINA Progress Bar + Losses --------- diff --git a/docs/source/_rst/callback/processing/data_normalizer.rst b/docs/source/_rst/callback/processing/data_normalizer.rst new file mode 100644 index 000000000..358d2f472 --- /dev/null +++ b/docs/source/_rst/callback/processing/data_normalizer.rst @@ -0,0 +1,9 @@ +Data Normalizer +======================= +.. currentmodule:: pina.callback.processing.data_normalizer + +.. automodule:: pina._src.callback.processing.data_normalizer + +.. autoclass:: pina._src.callback.processing.data_normalizer.DataNormalizer + :members: + :show-inheritance: diff --git a/docs/source/_rst/callback/processing/metric_tracker.rst b/docs/source/_rst/callback/processing/metric_tracker.rst index 202522831..22d7cc229 100644 --- a/docs/source/_rst/callback/processing/metric_tracker.rst +++ b/docs/source/_rst/callback/processing/metric_tracker.rst @@ -1,8 +1,10 @@ Metric Tracker ================== .. currentmodule:: pina.callback.processing.metric_tracker + .. automodule:: pina._src.callback.processing.metric_tracker - :show-inheritance: -.. autoclass:: MetricTracker + +.. autoclass:: pina._src.callback.processing.metric_tracker.MetricTracker :members: - :show-inheritance: \ No newline at end of file + :show-inheritance: + :noindex: diff --git a/docs/source/_rst/callback/processing/normalizer_data_callback.rst b/docs/source/_rst/callback/processing/normalizer_data_callback.rst deleted file mode 100644 index 31fd769c8..000000000 --- a/docs/source/_rst/callback/processing/normalizer_data_callback.rst +++ /dev/null @@ -1,9 +0,0 @@ -Normalizer Data -======================= - -.. currentmodule:: pina.callback.processing.normalizer_data_callback -.. automodule:: pina._src.callback.processing.normalizer_data_callback - :show-inheritance: -.. autoclass:: NormalizerDataCallback - :members: - :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/callback/processing/pina_progress_bar.rst b/docs/source/_rst/callback/processing/pina_progress_bar.rst index da3a878ba..9c64678eb 100644 --- a/docs/source/_rst/callback/processing/pina_progress_bar.rst +++ b/docs/source/_rst/callback/processing/pina_progress_bar.rst @@ -1,8 +1,9 @@ PINA Progress Bar ================== .. currentmodule:: pina.callback.processing.pina_progress_bar + .. automodule:: pina._src.callback.processing.pina_progress_bar - :show-inheritance: -.. autoclass:: PINAProgressBar + +.. autoclass:: pina._src.callback.processing.pina_progress_bar.PINAProgressBar :members: - :show-inheritance: \ No newline at end of file + :show-inheritance: diff --git a/docs/source/_rst/callback/refinement/base_refinement.rst b/docs/source/_rst/callback/refinement/base_refinement.rst new file mode 100644 index 000000000..5f8eaf218 --- /dev/null +++ b/docs/source/_rst/callback/refinement/base_refinement.rst @@ -0,0 +1,7 @@ +Base Refinement +======================= + +.. 
currentmodule:: pina.callback.refinement.base_refinement +.. autoclass:: pina._src.callback.refinement.base_refinement.BaseRefinement + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/callback/refinement/r3_refinement.rst b/docs/source/_rst/callback/refinement/r3_refinement.rst index 5f0da6ea6..0d787c840 100644 --- a/docs/source/_rst/callback/refinement/r3_refinement.rst +++ b/docs/source/_rst/callback/refinement/r3_refinement.rst @@ -1,7 +1,7 @@ -Refinments callbacks +R3 Refinement ======================= -.. currentmodule:: pina.callback +.. currentmodule:: pina.callback.refinement.r3_refinement .. autoclass:: pina._src.callback.refinement.r3_refinement.R3Refinement :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/callback/refinement/refinement_interface.rst b/docs/source/_rst/callback/refinement/refinement_interface.rst index d1de6429b..1af845800 100644 --- a/docs/source/_rst/callback/refinement/refinement_interface.rst +++ b/docs/source/_rst/callback/refinement/refinement_interface.rst @@ -1,7 +1,7 @@ Refinement Interface ======================= -.. currentmodule:: pina.callback +.. currentmodule:: pina.callback.refinement.refinement_interface .. autoclass:: pina._src.callback.refinement.refinement_interface.RefinementInterface :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/optim/optimizer_interface.rst b/docs/source/_rst/optim/optimizer_interface.rst index afd62f6a0..23a933bae 100644 --- a/docs/source/_rst/optim/optimizer_interface.rst +++ b/docs/source/_rst/optim/optimizer_interface.rst @@ -1,7 +1,7 @@ -Optimizer -============ +Optimizer Interface +===================== .. currentmodule:: pina.optim.optimizer_interface -.. autoclass:: pina._src.optim.optimizer_interface.Optimizer +.. autoclass:: pina._src.optim.optimizer_interface.OptimizerInterface :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/optim/scheduler_interface.rst b/docs/source/_rst/optim/scheduler_interface.rst index 0795c34e3..03b3e83f7 100644 --- a/docs/source/_rst/optim/scheduler_interface.rst +++ b/docs/source/_rst/optim/scheduler_interface.rst @@ -1,7 +1,7 @@ -Scheduler -============= +Scheduler Interface +===================== .. currentmodule:: pina.optim.scheduler_interface -.. autoclass:: pina._src.optim.scheduler_interface.Scheduler +.. autoclass:: pina._src.optim.scheduler_interface.SchedulerInterface :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/optim/torch_optimizer.rst b/docs/source/_rst/optim/torch_optimizer.rst index 67ab59164..54bfe9a3a 100644 --- a/docs/source/_rst/optim/torch_optimizer.rst +++ b/docs/source/_rst/optim/torch_optimizer.rst @@ -1,4 +1,4 @@ -TorchOptimizer +Torch Optimizer =============== .. currentmodule:: pina.optim.torch_optimizer diff --git a/docs/source/_rst/optim/torch_scheduler.rst b/docs/source/_rst/optim/torch_scheduler.rst index 272ba631f..59260533e 100644 --- a/docs/source/_rst/optim/torch_scheduler.rst +++ b/docs/source/_rst/optim/torch_scheduler.rst @@ -1,4 +1,4 @@ -TorchScheduler +Torch Scheduler =============== .. 
currentmodule:: pina.optim.torch_scheduler diff --git a/pina/_src/callback/optim/switch_optimizer.py b/pina/_src/callback/optim/switch_optimizer.py index 4f6f0be09..36561fa28 100644 --- a/pina/_src/callback/optim/switch_optimizer.py +++ b/pina/_src/callback/optim/switch_optimizer.py @@ -1,27 +1,36 @@ """Module for the SwitchOptimizer callback.""" from lightning.pytorch.callbacks import Callback -from pina._src.optim.torch_optimizer import TorchOptimizer -from pina._src.core.utils import check_consistency +from pina._src.optim.optimizer_interface import OptimizerInterface +from pina._src.core.utils import check_consistency, check_positive_integer class SwitchOptimizer(Callback): """ - PINA Implementation of a Lightning Callback to switch optimizer during - training. + Lightning callback for dynamically replacing optimizers during training. + + This callback enables switching to one or more new optimizers at a specified + epoch without restarting the training loop. It is particularly useful for + staged optimization strategies (e.g., coarse-to-fine training or optimizer + warm-up phases), where different optimizers are applied sequentially. + + At the target epoch, the provided optimizers are hooked to the model + parameters and replace the current optimizers in both the PINA solver and + the Lightning trainer strategy. """ def __init__(self, new_optimizers, epoch_switch): """ - This callback allows switching between different optimizers during - training, enabling the exploration of multiple optimization strategies - without interrupting the training process. + Initialization of the :class:`SwitchOptimizer` class. :param new_optimizers: The model optimizers to switch to. Can be a single :class:`torch.optim.Optimizer` instance or a list of them for multiple model solver. - :type new_optimizers: pina.optim.TorchOptimizer | list + :type new_optimizers: pina.optim.OptimizerInterface | list :param int epoch_switch: The epoch at which the optimizer switch occurs. + :raises AssertionError: If ``epoch_switch`` is not a positive integer. + :raises ValueError: If any of the provided optimizers are not instances + of :class:`pina.optim.OptimizerInterface`. Example: >>> optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01) @@ -31,19 +40,14 @@ def __init__(self, new_optimizers, epoch_switch): """ super().__init__() - # Check if epoch_switch is greater than 1 - if epoch_switch < 1: - raise ValueError("epoch_switch must be greater than one.") + # Check consistency + check_positive_integer(epoch_switch, strict=True) + check_consistency(new_optimizers, OptimizerInterface) # If new_optimizers is not a list, convert it to a list if not isinstance(new_optimizers, list): new_optimizers = [new_optimizers] - # Check consistency - check_consistency(epoch_switch, int) - for optimizer in new_optimizers: - check_consistency(optimizer, TorchOptimizer) - # Store the new optimizers and epoch switch self._new_optimizers = new_optimizers self._epoch_switch = epoch_switch @@ -52,9 +56,9 @@ def on_train_epoch_start(self, trainer, __): """ Switch the optimizer at the start of the specified training epoch. - :param lightning.pytorch.Trainer trainer: The trainer object managing - the training process. - :param _: Placeholder argument (not used). + :param Trainer trainer: The trainer object managing the training + process. + :param __: Placeholder argument, not used. 
""" # Check if the current epoch matches the switch epoch if trainer.current_epoch == self._epoch_switch: diff --git a/pina/_src/callback/optim/switch_scheduler.py b/pina/_src/callback/optim/switch_scheduler.py index 3a9215f17..61284fb50 100644 --- a/pina/_src/callback/optim/switch_scheduler.py +++ b/pina/_src/callback/optim/switch_scheduler.py @@ -1,30 +1,31 @@ """Module for the SwitchScheduler callback.""" from lightning.pytorch.callbacks import Callback -from pina._src.optim.torch_scheduler import TorchScheduler +from pina._src.optim.scheduler_interface import SchedulerInterface from pina._src.core.utils import check_consistency, check_positive_integer class SwitchScheduler(Callback): """ - Callback to switch scheduler during training. + Lightning callback for dynamically replacing schedulers during training. + + This callback enables switching to new scheduler(s) at a specified epoch + without interrupting the training loop. It is useful for staged training + strategies where different learning rate policies are applied sequentially. """ def __init__(self, new_schedulers, epoch_switch): """ - This callback allows switching between different schedulers during - training, enabling the exploration of multiple optimization strategies - without interrupting the training process. + Initialization of the :class:`SwitchScheduler` class. :param new_schedulers: The scheduler or list of schedulers to switch to. Use a single scheduler for single-model solvers, or a list of schedulers when working with multiple models. - :type new_schedulers: pina.optim.TorchScheduler | - list[pina.optim.TorchScheduler] + :type new_schedulers: SchedulerInterface | list[SchedulerInterface] :param int epoch_switch: The epoch at which the scheduler switch occurs. - :raises AssertionError: If epoch_switch is less than 1. - :raises ValueError: If each scheduler in ``new_schedulers`` is not an - instance of :class:`pina.optim.TorchScheduler`. + :raises AssertionError: If ``epoch_switch`` is not a positive integer. + :raises ValueError: If any of the provided schedulers are not instances + of :class:`pina.optim.SchedulerInterface`. Example: >>> scheduler = TorchScheduler( @@ -36,17 +37,14 @@ def __init__(self, new_schedulers, epoch_switch): """ super().__init__() - # Check if epoch_switch is greater than 1 - check_positive_integer(epoch_switch - 1, strict=True) + # Check consistency + check_positive_integer(epoch_switch, strict=True) + check_consistency(new_schedulers, SchedulerInterface) # If new_schedulers is not a list, convert it to a list if not isinstance(new_schedulers, list): new_schedulers = [new_schedulers] - # Check consistency - for scheduler in new_schedulers: - check_consistency(scheduler, TorchScheduler) - # Store the new schedulers and epoch switch self._new_schedulers = new_schedulers self._epoch_switch = epoch_switch @@ -55,9 +53,9 @@ def on_train_epoch_start(self, trainer, __): """ Switch the scheduler at the start of the specified training epoch. - :param lightning.pytorch.Trainer trainer: The trainer object managing + :param Trainer trainer: The trainer object managing the training process. - :param __: Placeholder argument (not used). + :param __: Placeholder argument, not used. 
""" # Check if the current epoch matches the switch epoch if trainer.current_epoch == self._epoch_switch: diff --git a/pina/_src/callback/processing/data_normalizer.py b/pina/_src/callback/processing/data_normalizer.py new file mode 100644 index 000000000..23512813c --- /dev/null +++ b/pina/_src/callback/processing/data_normalizer.py @@ -0,0 +1,206 @@ +"""Module for the Data Normalizer callback.""" + +from typing import Callable +import torch +from lightning.pytorch import Callback +from pina._src.core.utils import check_consistency +from pina._src.core.label_tensor import LabelTensor +from pina._src.condition.condition import InputTargetCondition + + +class DataNormalizer(Callback): + r""" + Callback for dataset normalization on input-target conditions. + + This callback computes and applies a normalization transform to either + input or target tensors within a dataset. The transformation is defined as: + + .. math:: + + x_{\text{norm}} = \frac{x - \mu}{\sigma}, + + where :math:`\mu` and :math:`\sigma` are computed using the provided + ``shift_fn`` and ``scale_fn`` functions, respectively. Normalization + parameters are estimated from the training dataset and then applied in-place + to the selected datasets depending on the chosen stage. + + .. note:: + + This callback ignores all conditions that are not instances of + :class:`~pina.condition.InputTargetCondition`. + + :Example: + + >>> DataNormalizer( + ... scale_fn=torch.std, + ... shift_fn=torch.mean, + ... stage="all", + ... apply_to="input", + ... ) + """ + + # Define valid options for stage and apply_to parameters + _VALID_STAGES = {"train", "validate", "test", "all"} + _VALID_APPLY_TO = {"input", "target"} + + def __init__( + self, + scale_fn=torch.std, + shift_fn=torch.mean, + stage="all", + apply_to="input", + ): + """ + Initialization of the :class:`DataNormalizer` class. + + :param Callable scale_fn: The function used to compute the scaling + factor. Default is ``torch.std``. + :param Callable shift_fn: The function used to compute the shifting + factor. Default is ``torch.mean``. + :param str stage: The stage during which normalization is applied. + Available options are ``"train"``, ``"validate"``, ``"test"``, and + ``"all"``. Default is ``"all"``. + :param str apply_to: Specifies whether normalization is applied to + ``"input"`` or ``"target"`` tensors. Default is ``"input"``. + :raises ValueError: If ``scale_fn`` is not Callable. + :raises ValueError: If ``shift_fn`` is not Callable. + :raises ValueError: If ``stage`` is invalid. + :raises ValueError: If ``apply_to`` is invalid. + """ + super().__init__() + + # Check consistency + check_consistency(scale_fn, Callable) + check_consistency(shift_fn, Callable) + check_consistency(stage, str) + check_consistency(apply_to, str) + + # Validate stage parameter + if stage not in self._VALID_STAGES: + raise ValueError( + "Invalid value for 'stage'. Available options are " + f"{self._VALID_STAGES}. Got {stage}." + ) + + # Validate apply_to parameter + if apply_to not in self._VALID_APPLY_TO: + raise ValueError( + "Invalid value for 'apply_to'. Available options are " + f"{self._VALID_APPLY_TO}. Got {apply_to}." + ) + + # Initialize attributes + self.scale_fn = scale_fn + self.shift_fn = shift_fn + self.stage = stage + self.apply_to = apply_to + self._normalizer = {} + self._normalized_conditions = set() + + def setup(self, trainer, pl_module, stage): + """ + Compute and apply normalization during the setup phase. 
+ + :param Trainer trainer: The trainer instance managing the execution. + :param SolverInterface pl_module: The solver module being executed. + :param str stage: Current execution stage. + :raises NotImplementedError: If the dataset is graph-based and + therefore unsupported. + """ + # Check if any condition contains graph-based data + if any( + hasattr(ds.condition.data, "graph_key") + for ds in trainer.datamodule.train_datasets.values() + ): + raise NotImplementedError( + "DataNormalizer is not compatible with graph-based datasets." + ) + + # Extract input-target conditions + conditions_to_normalize = [ + name + for name, cond in pl_module.problem.conditions.items() + if isinstance(cond, InputTargetCondition) + ] + + # Extract the dictionary of all datasets + dataset = trainer.datamodule.train_datasets + + # Compute scale and shift parameters if not already computed + if not self.normalizer: + + # Iterate over conditions and compute normalization parameters + for cond in conditions_to_normalize: + pts = self._get_data(dataset, cond) + shift = self.shift_fn(pts) + scale = self.scale_fn(pts) + + self._normalizer[cond] = { + "shift": shift, + "scale": scale, + } + + # Apply normalization to training datasets + if stage == "fit" and self.stage in ["train", "all"]: + self.normalize_dataset(trainer.datamodule.train_datasets) + + if stage == "fit" and self.stage in ["validate", "all"]: + self.normalize_dataset(trainer.datamodule.val_datasets) + + if stage == "test" and self.stage in ["test", "all"]: + self.normalize_dataset(trainer.datamodule.test_datasets) + + return super().setup(trainer, pl_module, stage) + + def normalize_dataset(self, dataset): + """ + Apply normalization to all datasets in-place. + + Each condition is updated using precomputed normalization parameters. + The transformation preserves tensor types. + + :param dict dataset: The mapping between condition names and their + associated dataset subsets. + """ + # Iterate over conditions and apply normalization + for cond, norm_params in self.normalizer.items(): + if cond in self._normalized_conditions: + continue + + # Extract the points to normalize and the normalization parameters + data_container = getattr(dataset[cond].condition, self.apply_to) + points = data_container.data + scale = norm_params["scale"] + shift = norm_params["shift"] + + # Apply normalization + scaled_pts = (points - shift) / scale + if isinstance(data_container, LabelTensor): + scaled_pts = LabelTensor(scaled_pts, data_container.labels) + + # Update the dataset in-place + data_container.data = scaled_pts + self._normalized_conditions.add(cond) + + def _get_data(self, dataset, cond): + """ + Extract the selected data field from the dataset for a given condition. + + :param dict dataset: The mapping between condition names and their + associated dataset subsets. + :param str cond: The condition name. + :return: The selected input or target data. + :rtype: torch.Tensor + """ + return getattr(dataset[cond].condition, self.apply_to).data + + @property + def normalizer(self): + """ + The dictionary mapping each condition to its corresponding ``shift`` and + ``scale`` values. + + :return: The dictionary of normalization parameters. 
+        :rtype: dict
+        """
+        return self._normalizer
diff --git a/pina/_src/callback/processing/metric_tracker.py b/pina/_src/callback/processing/metric_tracker.py
index 9b1dc9d4a..68e6d35e0 100644
--- a/pina/_src/callback/processing/metric_tracker.py
+++ b/pina/_src/callback/processing/metric_tracker.py
@@ -3,52 +3,78 @@
 import copy
 import torch
 from lightning.pytorch.callbacks import Callback
+from pina._src.core.utils import check_consistency
 
 
 class MetricTracker(Callback):
     """
-    Lightning Callback for Metric Tracking.
+    Callback for collecting selected metrics logged during training.
     """
 
     def __init__(self, metrics_to_track=None):
         """
-        Tracks specified metrics during training.
+        Initialization of the :class:`MetricTracker` class.
 
-        :param metrics_to_track: List of metrics to track.
-            Defaults to train loss.
-        :type metrics_to_track: list[str], optional
+        :param metrics_to_track: The names of the metrics to collect. If
+            ``None``, defaults to ``["train_loss"]``. When a batch size is
+            available, each tracked name is expanded to its ``_step`` and
+            ``_epoch`` variants (see :meth:`setup`). Default is ``None``.
+        :type metrics_to_track: str | list[str]
+        :raises ValueError: If any of the provided metric names are not strings.
        """
        super().__init__()
-        self._collection = []
-        # Default to tracking 'train_loss' if not specified
+
+        # Check consistency
+        if metrics_to_track is not None:
+            check_consistency(metrics_to_track, str)
+
+        # Convert to list if a single string is provided
+        if isinstance(metrics_to_track, str):
+            metrics_to_track = [metrics_to_track]
+
+        # Initialize the collection list and store the metrics to track
        self.metrics_to_track = metrics_to_track
+        self._collection = []
 
    def setup(self, trainer, pl_module, stage):
        """
-        Called when fit, validate, test, predict, or tune begins.
+        Configure the metrics to track before execution starts.
 
-        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
-        :param SolverInterface pl_module: A
-            :class:`~pina.solver.solver.SolverInterface` instance.
-        :param str stage: Either 'fit', 'test' or 'predict'.
+        When a batch size is provided (i.e. ``trainer.batch_size`` is not
+        ``None``), metric names are expanded to match Lightning's logging
+        convention: for each metric ``m``, both ``m_step`` and ``m_epoch`` are
+        tracked. For example, ``"train_loss"`` becomes
+        ``["train_loss_step", "train_loss_epoch"]``.
+
+        :param Trainer trainer: The trainer instance managing the execution.
+        :param SolverInterface pl_module: The solver module being executed.
+        :param str stage: Current execution stage.
        """
-        if self.metrics_to_track is None and trainer.batch_size is None:
+        # Default to tracking the train loss if no metrics are specified
+        if self.metrics_to_track is None:
            self.metrics_to_track = ["train_loss"]
-        elif self.metrics_to_track is None:
-            self.metrics_to_track = ["train_loss_epoch"]
+
+        # If a batch size is provided, expand metric names to match convention
+        if trainer.batch_size is not None:
+            self.metrics_to_track = [
+                f"{metric}_{suffix}"
+                for metric in self.metrics_to_track
+                for suffix in ("step", "epoch")
+            ]
+
        return super().setup(trainer, pl_module, stage)
 
-    def on_train_epoch_end(self, trainer, pl_module):
+    def on_train_epoch_end(self, trainer, __):
        """
-        Collect and track metrics at the end of each training epoch.
+        Store the selected logged metrics at the end of each training epoch.
 
-        :param trainer: The trainer object managing the training process.
-        :type trainer: pytorch_lightning.Trainer
-        :param pl_module: The model being trained (not used here).
+ :param Trainer trainer: The trainer instance managing the execution. + :param __: Placeholder argument, not used. """ - # Track metrics after the first epoch onwards + # Only collect metrics after the first epoch to ensure they are logged if trainer.current_epoch > 0: - # Append only the tracked metrics to avoid unnecessary data + + # Collect the metrics that are being tracked tracked_metrics = { k: v for k, v in trainer.logged_metrics.items() @@ -59,20 +85,21 @@ def on_train_epoch_end(self, trainer, pl_module): @property def metrics(self): """ - Aggregate collected metrics over all epochs. + Return the collected metrics stacked over the tracked epochs. - :return: A dictionary containing aggregated metric values. - :rtype: dict + :return: The dictionary mapping each metric name to a tensor containing + its values across epochs. Returns an empty dictionary if no metrics + have been collected. + :rtype: dict[str, torch.Tensor] """ if not self._collection: return {} - # Get intersection of keys across all collected dictionaries + # Identify the common keys across all collected metric dictionaries common_keys = set(self._collection[0]).intersection( *self._collection[1:] ) - # Stack the metric values for common keys and return return { k: torch.stack([dic[k] for dic in self._collection]) for k in common_keys diff --git a/pina/_src/callback/processing/normalizer_data_callback.py b/pina/_src/callback/processing/normalizer_data_callback.py deleted file mode 100644 index 2524f5765..000000000 --- a/pina/_src/callback/processing/normalizer_data_callback.py +++ /dev/null @@ -1,228 +0,0 @@ -"""Module for the Normalizer callback.""" - -import torch -from lightning.pytorch import Callback -from pina._src.core.label_tensor import LabelTensor -from pina._src.core.utils import check_consistency, is_function -from pina._src.condition.condition import InputTargetCondition -from pina._src.data.dataset import PinaGraphDataset - - -class NormalizerDataCallback(Callback): - r""" - A Callback used to normalize the dataset inputs or targets according to - user-provided scale and shift functions. - - The transformation is applied as: - - .. math:: - - x_{\text{new}} = \frac{x - \text{shift}}{\text{scale}} - - :Example: - - >>> NormalizerDataCallback() - >>> NormalizerDataCallback( - ... scale_fn: torch.std, - ... shift_fn: torch.mean, - ... stage: "all", - ... apply_to: "input", - ... ) - """ - - def __init__( - self, - scale_fn=torch.std, - shift_fn=torch.mean, - stage="all", - apply_to="input", - ): - """ - Initialization of the :class:`NormalizerDataCallback` class. - - :param Callable scale_fn: The function to compute the scaling factor. - Default is ``torch.std``. - :param Callable shift_fn: The function to compute the shifting factor. - Default is ``torch.mean``. - :param str stage: The stage in which normalization is applied. - Accepted values are "train", "validate", "test", or "all". - Default is ``"all"``. - :param str apply_to: Whether to normalize "input" or "target" data. - Default is ``"input"``. - :raises ValueError: If ``scale_fn`` is not callable. - :raises ValueError: If ``shift_fn`` is not callable. 
- """ - super().__init__() - - # Validate parameters - self.apply_to = self._validate_apply_to(apply_to) - self.stage = self._validate_stage(stage) - - # Validate functions - if not is_function(scale_fn): - raise ValueError(f"scale_fn must be Callable, got {scale_fn}") - if not is_function(shift_fn): - raise ValueError(f"shift_fn must be Callable, got {shift_fn}") - self.scale_fn = scale_fn - self.shift_fn = shift_fn - - # Initialize normalizer dictionary - self._normalizer = {} - - def _validate_apply_to(self, apply_to): - """ - Validate the ``apply_to`` parameter. - - :param str apply_to: The candidate value for the ``apply_to`` parameter. - :raises ValueError: If ``apply_to`` is neither "input" nor "target". - :return: The validated ``apply_to`` value. - :rtype: str - """ - check_consistency(apply_to, str) - if apply_to not in {"input", "target"}: - raise ValueError( - f"apply_to must be either 'input' or 'target', got {apply_to}" - ) - - return apply_to - - def _validate_stage(self, stage): - """ - Validate the ``stage`` parameter. - - :param str stage: The candidate value for the ``stage`` parameter. - :raises ValueError: If ``stage`` is not one of "train", "validate", - "test", or "all". - :return: The validated ``stage`` value. - :rtype: str - """ - check_consistency(stage, str) - if stage not in {"train", "validate", "test", "all"}: - raise ValueError( - "stage must be one of 'train', 'validate', 'test', or 'all'," - f" got {stage}" - ) - - return stage - - def setup(self, trainer, pl_module, stage): - """ - Apply normalization during setup. - - :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance. - :param SolverInterface pl_module: A - :class:`~pina.solver.solver.SolverInterface` instance. - :param str stage: The current stage. - :raises RuntimeError: If the training dataset is not available when - computing normalization parameters. - :return: The result of the parent setup. - :rtype: Any - - :raises NotImplementedError: If the dataset is graph-based. - """ - - # Ensure datsets are not graph-based - if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset): - raise NotImplementedError( - "NormalizerDataCallback is not compatible with " - "graph-based datasets." - ) - - # Extract conditions - conditions_to_normalize = [ - name - for name, cond in pl_module.problem.conditions.items() - if isinstance(cond, InputTargetCondition) - ] - - # Compute scale and shift parameters - if not self.normalizer: - if not trainer.datamodule.train_dataset: - raise RuntimeError( - "Training dataset is not available. Cannot compute " - "normalization parameters." - ) - self._compute_scale_shift( - conditions_to_normalize, trainer.datamodule.train_dataset - ) - - # Apply normalization based on the specified stage - if stage == "fit" and self.stage in ["train", "all"]: - self.normalize_dataset(trainer.datamodule.train_dataset) - if stage == "fit" and self.stage in ["validate", "all"]: - self.normalize_dataset(trainer.datamodule.val_dataset) - if stage == "test" and self.stage in ["test", "all"]: - self.normalize_dataset(trainer.datamodule.test_dataset) - - return super().setup(trainer, pl_module, stage) - - def _compute_scale_shift(self, conditions, dataset): - """ - Compute scale and shift parameters for each condition in the dataset. - - :param list conditions: The list of condition names. - :param dataset: The `~pina.data.dataset.PinaDataset` dataset. 
- """ - for cond in conditions: - if cond in dataset.conditions_dict: - data = dataset.conditions_dict[cond][self.apply_to] - shift = self.shift_fn(data) - scale = self.scale_fn(data) - self._normalizer[cond] = { - "shift": shift, - "scale": scale, - } - - @staticmethod - def _norm_fn(value, scale, shift): - """ - Normalize a value according to the scale and shift parameters. - - :param value: The input tensor to normalize. - :type value: torch.Tensor | LabelTensor - :param float scale: The scaling factor. - :param float shift: The shifting factor. - :return: The normalized tensor. - :rtype: torch.Tensor | LabelTensor - """ - scaled_value = (value - shift) / scale - if isinstance(value, LabelTensor): - scaled_value = LabelTensor(scaled_value, value.labels) - - return scaled_value - - def normalize_dataset(self, dataset): - """ - Apply in-place normalization to the dataset. - - :param PinaDataset dataset: The dataset to be normalized. - """ - # Initialize update dictionary - update_dataset_dict = {} - - # Iterate over conditions and apply normalization - for cond, norm_params in self.normalizer.items(): - points = dataset.conditions_dict[cond][self.apply_to] - scale = norm_params["scale"] - shift = norm_params["shift"] - normalized_points = self._norm_fn(points, scale, shift) - update_dataset_dict[cond] = { - self.apply_to: ( - LabelTensor(normalized_points, points.labels) - if isinstance(points, LabelTensor) - else normalized_points - ) - } - - # Update the dataset in-place - dataset.update_data(update_dataset_dict) - - @property - def normalizer(self): - """ - Get the dictionary of normalization parameters. - - :return: The dictionary of normalization parameters. - :rtype: dict - """ - return self._normalizer diff --git a/pina/_src/callback/processing/pina_progress_bar.py b/pina/_src/callback/processing/pina_progress_bar.py index 90c34f8cc..7a7c2a905 100644 --- a/pina/_src/callback/processing/pina_progress_bar.py +++ b/pina/_src/callback/processing/pina_progress_bar.py @@ -9,9 +9,19 @@ class PINAProgressBar(TQDMProgressBar): """ - PINA Implementation of a Lightning Callback for enriching the progress bar. + Custom progress bar callback for PINA training workflows. + + This callback extends the default Lightning progress bar by filtering the + displayed metrics. + + Metrics can refer either to condition-specific losses, identified by the + names assigned to the problem conditions, or to global losses. Global losses + are selected using ``"train"``, ``"val"``, or ``"test"``, and are internally + expanded to the corresponding logged loss metrics. """ + GLOBAL_LOSS_KEYS = ("train", "val", "test") + BAR_FORMAT = ( "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, " "{rate_noinv_fmt}{postfix}]" @@ -19,81 +29,88 @@ class PINAProgressBar(TQDMProgressBar): def __init__(self, metrics="val", **kwargs): """ - This class enables the display of only relevant metrics during training. + Initialization of the :class:`PINAProgressBar`. - :param metrics: Logged metrics to be shown during the training. - Must be a subset of the conditions keys defined in - :obj:`pina.condition.Condition`. + :param metrics: The names of the metrics to be shown in the progress + bar. Each entry can be either a key of a condition defined in the + problem or one of the global loss keys: ``"train"``, ``"val"``, or + ``"test"``. These global keys are internally expanded to the + corresponding logged loss names. Default is ``"val"``. 
:type metrics: str | list(str) | tuple(str) - - :Keyword Arguments: - The additional keyword arguments specify the progress bar and can be - choosen from the `pytorch-lightning TQDMProgressBar API - `_ - - Example: - >>> pbar = PINAProgressBar(['mean']) - >>> # ... Perform training ... - >>> trainer = Trainer(solver, callbacks=[pbar]) + :param dict kwargs: Additional keyword arguments passed to + :class:`lightning.pytorch.callbacks.TQDMProgressBar`. + :raises TypeError: If ``metrics`` contains non-string elements. """ super().__init__(**kwargs) - # check consistency - if not isinstance(metrics, (list, tuple)): - metrics = [metrics] + + # Check consistency check_consistency(metrics, str) - self._sorted_metrics = metrics - def get_metrics(self, trainer, pl_module): - r"""Combine progress bar metrics collected from the trainer with - standard metrics from get_standard_metrics. - Override this method to customize the items shown in the progress bar. - The progress bar metrics are sorted according to ``metrics``. + # Convert to list if a single string is provided + if isinstance(metrics, str): + metrics = [metrics] - Here is an example of how to override the defaults: + # Store the sorted metrics for later use in get_metrics + self._sorted_metrics = sorted(metrics) - .. code-block:: python + def get_metrics(self, trainer, __): + """ + Retrieve and filter metrics to be displayed in the progress bar. + + This method combines standard Lightning metrics with user-selected + progress bar metrics, retaining only the metrics specified at + initialization. - def get_metrics(self, trainer, model): - # don't show the version number - items = super().get_metrics(trainer, model) - items.pop("v_num", None) - return items + :param Trainer trainer: The trainer managing the training loop. + :param __: Placeholder argument, not used. + :return: Dictionary containing the metrics to display. + :rtype: dict - :return: Dictionary with the items to be displayed in the progress bar. - :rtype: tuple(dict) + .. note:: + This method overrides the default Lightning behavior. It can be + further customized by subclassing. """ + # Retrieve standard metrics and user-selected progress bar metrics standard_metrics = get_standard_metrics(trainer) - pbar_metrics = trainer.progress_bar_metrics - if pbar_metrics: - pbar_metrics = { - key: pbar_metrics[key] - for key in pbar_metrics + progress_bar_metrics = trainer.progress_bar_metrics + + # Filter progress bar metrics to include only specified keys + if progress_bar_metrics: + progress_bar_metrics = { + key: progress_bar_metrics[key] + for key in progress_bar_metrics if key in self._sorted_metrics } - return {**standard_metrics, **pbar_metrics} + + return {**standard_metrics, **progress_bar_metrics} def setup(self, trainer, pl_module, stage): """ - Check that the initialized metrics are available and correctly logged. + Configure the metrics to track before execution starts. - :param trainer: The trainer object managing the training process. - :type trainer: pytorch_lightning.Trainer - :param pl_module: Placeholder argument. + The requested metrics must be either names assigned to problem + conditions or global loss keys. The accepted global loss keys are + ``"train"``, ``"val"``, and ``"test"``. + + :param Trainer trainer: The trainer instance managing the execution. + :param SolverInterface pl_module: The solver module being executed. + :param str stage: Current execution stage. 
+ :raises KeyError: If a metric key is neither a condition key nor one of + ``"train"``, ``"val"``, or ``"test"``. """ - # Check if all keys in sort_keys are present in the dictionary + # Get the condition keys from the problem + condition_keys = trainer.solver.problem.conditions.keys() for key in self._sorted_metrics: - if ( - key not in trainer.solver.problem.conditions.keys() - and key != "train" - and key != "val" - ): - raise KeyError(f"Key '{key}' is not present in the dictionary") - # add the loss pedix - if trainer.batch_size is not None: - pedix = "_loss_epoch" - else: - pedix = "_loss" + if key not in condition_keys and key not in self.GLOBAL_LOSS_KEYS: + raise KeyError( + f"Key '{key}' is not a valid metric. It must be either a " + f"problem condition key or one of {self.GLOBAL_LOSS_KEYS}." + ) + + # Add the appropriate suffix to the metric names based on batch size + suffix = "_loss_epoch" if trainer.batch_size is not None else "_loss" self._sorted_metrics = [ - metric + pedix for metric in self._sorted_metrics + metric + suffix for metric in self._sorted_metrics ] + return super().setup(trainer, pl_module, stage) diff --git a/pina/_src/callback/refinement/base_refinement.py b/pina/_src/callback/refinement/base_refinement.py new file mode 100644 index 000000000..b855d0b28 --- /dev/null +++ b/pina/_src/callback/refinement/base_refinement.py @@ -0,0 +1,155 @@ +"""Module for the Base Refinement class.""" + +from lightning.pytorch import Callback +from pina._src.core.utils import check_consistency, check_positive_integer +from pina._src.callback.refinement.refinement_interface import ( + RefinementInterface, +) +from pina._src.solver.physics_informed_solver.pinn_interface import ( + PINNInterface, +) + + +class BaseRefinement(Callback, RefinementInterface): + """ + Base class for all refinement strategies, implementing common functionality. + + A refinement strategy is responsible for dynamically updating the training + dataset during optimization, typically by resampling points in the domain + based on model behavior (e.g., error-driven refinement). + + All specific refinement strategies should inherit from this class and + implement its abstract methods. + + This class is not meant to be instantiated directly. + """ + + def __init__(self, sample_every, condition_to_update=None): + """ + Initialization of the :class:`BaseRefinement` class. + + :param int sample_every: The number of epochs between successive + refinement steps. + :param condition_to_update: The condition(s) to be updated during + refinement. If ``None``, all conditions associated with a domain are + updated. Default is ``None``. + :type condition_to_update: str | list[str] | tuple[str] + :raises AssertionError: If ``sample_every`` is not a positive integer. + :raises ValueError: If ``condition_to_update``, when provided, is not a + string or an iterable of strings. 
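+
+        :Example:
+
+            >>> # Hypothetical usage sketch: BaseRefinement is abstract, so
+            >>> # a concrete strategy such as R3Refinement is instantiated.
+            >>> refinement = R3Refinement(sample_every=5)
+            >>> trainer = Trainer(solver, callbacks=[refinement])
+            >>> # ... Perform training ...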
+ """ + # Check consistency + check_positive_integer(sample_every, strict=True) + if condition_to_update is not None: + if isinstance(condition_to_update, str): + condition_to_update = [condition_to_update] + check_consistency([condition_to_update], (list, tuple)) + check_consistency(condition_to_update, str) + + # Initialize attributes + self._condition_to_update = condition_to_update + self.sample_every = sample_every + self._initial_population_size = None + self._dataset = None + + def on_train_start(self, trainer, solver): + """ + This method is called once before training begins and is typically used + to initialize datasets, sampling conditions, or internal state. + + :param Trainer trainer: The trainer managing the training loop. + :param SolverInterface solver: The solver associated with the trainer. + :raise RuntimeError: If the solver is not physics-informed (i.e., does + not implement PINNInterface). + :raise RuntimeError: If any of the specified conditions do not exist in + the problem. + :raise RuntimeError: If any of the specified conditions do not have a + 'domain' attribute for sampling. + """ + # Check solver consistency + if not isinstance(solver, PINNInterface): + raise RuntimeError( + "Refinement strategies require a physics-informed solver. " + f"Got '{type(solver).__name__}'." + ) + + # Initialize conditions to update if not provided + if self._condition_to_update is None: + self._condition_to_update = [ + name + for name, cond in solver.problem.conditions.items() + if hasattr(cond, "domain") + ] + + # Validate conditions and solver + for cond in self._condition_to_update: + + # Check if condition exists in the problem + if cond not in solver.problem.conditions: + raise RuntimeError( + f"Unknown condition '{cond}'. Available conditions: " + f"{list(solver.problem.conditions.keys())}." + ) + + # Check if condition has a domain to sample from + if not hasattr(solver.problem.conditions[cond], "domain"): + raise RuntimeError( + f"Condition '{cond}' has no 'domain' attribute and cannot " + "be used for sampling." + ) + + # Initialize dataset and compute initial population size + self._dataset = trainer.datamodule.train_datasets + self._initial_population_size = { + cond: self.dataset[cond].length + for cond in self._condition_to_update + } + + def on_train_epoch_end(self, trainer, solver): + """ + Apply refinement at the end of a training epoch. + + This method is invoked after each epoch and can update the dataset based + on the current state of the model. + + :param Trainer trainer: The trainer managing the training loop. + :param SolverInterface solver: The solver associated with the trainer. + """ + # Store current epoch + epoch = trainer.current_epoch + + # Sample if it's time to refine + if epoch % self.sample_every == 0 and epoch != 0: + + # Update points for each condition to update + for name in self._condition_to_update: + + current_points = solver.problem.conditions[name].data.input + new_points = self.sample(current_points, name, solver) + solver.problem.conditions[name].data.input = new_points + + @property + def dataset(self): + """ + The training datasets managed by the refinement strategy. + + The dataset is stored as a dictionary whose keys are condition names and + whose values are the corresponding dataset subsets. The content of this + dictionary can be updated dynamically during refinement. + + :return: The mapping between condition names and dataset subsets. 
+ :rtype: dict + """ + return self._dataset + + @property + def initial_population_size(self): + """ + Initial size of the sampled dataset for each condition before any + refinement is applied. + + :return: A mapping between each condition name and its initial number + of sampled points. + :rtype: dict[str, int] + """ + return self._initial_population_size diff --git a/pina/_src/callback/refinement/r3_refinement.py b/pina/_src/callback/refinement/r3_refinement.py index b8bcc7285..21957bcf1 100644 --- a/pina/_src/callback/refinement/r3_refinement.py +++ b/pina/_src/callback/refinement/r3_refinement.py @@ -1,35 +1,34 @@ """Module for the R3Refinement callback.""" import torch -from pina._src.callback.refinement.refinement_interface import ( - RefinementInterface, -) -from pina._src.core.label_tensor import LabelTensor from pina._src.core.utils import check_consistency +from pina._src.core.label_tensor import LabelTensor from pina._src.loss.loss_interface import LossInterface +from pina._src.callback.refinement.base_refinement import BaseRefinement -class R3Refinement(RefinementInterface): +class R3Refinement(BaseRefinement): """ - PINA Implementation of the R3 Refinement Callback. + Refinement strategy based on the R3 (Retain-Resample-Release) algorithm. + + This method adaptively updates collocation points by retaining points with + high residuals, resampling new points in the domain, releasing points with + low residuals. - This callback implements the R3 (Retain-Resample-Release) routine for - sampling new points based on adaptive search. - The algorithm incrementally accumulates collocation points in regions - of high PDE residuals, and releases those with low residuals. - Points are sampled uniformly in all regions where sampling is needed. + The objective is to concentrate sampling in regions where the PDE residual + is large, improving training efficiency and solution accuracy. .. seealso:: - Original Reference: Daw, Arka, et al. *Mitigating Propagation - Failures in Physics-informed Neural Networks - using Retain-Resample-Release (R3) Sampling. (2023)*. + **Original Reference**: Daw, Arka, et al. (2023). + *Mitigating Propagation Failures in Physics-informed Neural Networks + using Retain-Resample-Release (R3) Sampling*. DOI: `10.48550/arXiv.2207.02338 `_ :Example: - >>> r3_callback = R3Refinement(sample_every=5) + >>> r3 = R3Refinement(sample_every=5) """ def __init__( @@ -39,20 +38,22 @@ def __init__( condition_to_update=None, ): """ - Initialization of the :class:`R3Refinement` callback. - - :param int sample_every: The sampling frequency. - :param loss: The loss function to compute the residuals. - Default is :class:`~torch.nn.L1Loss`. - :type loss: LossInterface | :class:`~torch.nn.modules.loss._Loss` - :param condition_to_update: The conditions to update during the - refinement process. If None, all conditions will be updated. - Default is None. - :type condition_to_update: list(str) | tuple(str) | str + Initialization of the :class:`R3Refinement` class. + + :param int sample_every: The number of epochs between successive + refinement steps. + :param residual_loss: The loss used to evaluate residual magnitude. Must + be a subclass of :class:`torch.nn.Module` or + :class:`pina.loss.LossInterface`. + Default is :class:`torch.nn.L1Loss`. + :type residual_loss: LossInterface | torch.nn.modules.loss._Loss + :param condition_to_update: The condition(s) to be updated during + refinement. If ``None``, all conditions associated with a domain are + updated. Default is ``None``. 
+ :type condition_to_update: str | list[str] | tuple[str] :raises ValueError: If the condition_to_update is neither a string nor an iterable of strings. - :raises TypeError: If the residual_loss is not a subclass of - :class:`~torch.nn.Module`. + :raises ValueError: If the residual_loss is not a valid loss class. """ super().__init__(sample_every, condition_to_update) @@ -63,18 +64,17 @@ def __init__( subclass=True, ) - # Save loss function + # Store the loss function for computing residuals during sampling self.loss_fn = residual_loss(reduction="none") def sample(self, current_points, condition_name, solver): """ - Sample new points based on the R3 refinement strategy. + Generate new sample points for a given condition. - :param current_points: The current points in the domain. - :type current_points: LabelTensor | torch.Tensor - :param str condition_name: The name of the condition to update. - :param PINNInterface solver: The solver using this callback. - :return: The new samples generated by the R3 strategy. + :param LabelTensor current_points: The existing points in the domain. + :param str condition_name: The identifier of the condition to refine. + :param SolverInterface solver: The solver used for sampling decisions. + :return: Newly sampled points. :rtype: LabelTensor """ # Retrieve condition and current points @@ -82,7 +82,7 @@ def sample(self, current_points, condition_name, solver): condition = solver.problem.conditions[condition_name] current_points = current_points.to(device).requires_grad_(True) - # Compute residuals for the given condition (averaged over all fields) + # Compute residuals for the given condition target = solver.compute_residual(current_points, condition.equation) residuals = self.loss_fn(target, torch.zeros_like(target)).mean( dim=tuple(range(1, target.ndim)) @@ -94,11 +94,12 @@ def sample(self, current_points, condition_name, solver): num_old_points = self.initial_population_size[condition_name] # Select points with residual above the mean - mask = (residuals > residuals.mean()).flatten() - if mask.any(): - high_residual_pts = current_points[mask] - high_residual_pts.labels = current_points.labels - samples = domain.sample(num_old_points - len(high_residual_pts)) - return LabelTensor.cat([high_residual_pts, samples.to(device)]) - - return domain.sample(num_old_points, "random") + mask = (residuals >= residuals.mean()).flatten() + high_residual_pts = current_points[mask] + high_residual_pts.labels = current_points.labels + + # Sample new points to maintain the initial population size + num_new_pts = max(num_old_points - len(high_residual_pts), 0) + samples = domain.sample(num_new_pts, "random").to(device) + + return LabelTensor.cat([high_residual_pts, samples]) diff --git a/pina/_src/callback/refinement/refinement_interface.py b/pina/_src/callback/refinement/refinement_interface.py index 83ca8d8be..4c32c6556 100644 --- a/pina/_src/callback/refinement/refinement_interface.py +++ b/pina/_src/callback/refinement/refinement_interface.py @@ -1,157 +1,69 @@ -""" -RefinementInterface class for handling the refinement of points in a neural -network training process. 
-""" +"""Module for the Refinement Interface.""" from abc import ABCMeta, abstractmethod -from lightning.pytorch import Callback -from pina._src.core.utils import check_consistency -from pina._src.solver.physics_informed_solver.pinn_interface import ( - PINNInterface, -) -class RefinementInterface(Callback, metaclass=ABCMeta): +class RefinementInterface(metaclass=ABCMeta): """ - Interface class of Refinement approaches. + Abstract interface for all refinement strategies. """ - def __init__(self, sample_every, condition_to_update=None): - """ - Initializes the RefinementInterface. - - :param int sample_every: The number of epochs between each refinement. - :param condition_to_update: The conditions to update during the - refinement process. If None, all conditions with a domain will be - updated. Default is None. - :type condition_to_update: list(str) | tuple(str) | str - - """ - # check consistency of the input - check_consistency(sample_every, int) - if condition_to_update is not None: - if isinstance(condition_to_update, str): - condition_to_update = [condition_to_update] - if not isinstance(condition_to_update, (list, tuple)): - raise ValueError( - "'condition_to_update' must be iter of strings." - ) - check_consistency(condition_to_update, str) - # store - self.sample_every = sample_every - self._condition_to_update = condition_to_update - self._dataset = None - self._initial_population_size = None - + @abstractmethod def on_train_start(self, trainer, solver): """ - Called when the training begins. It initializes the conditions and - dataset. + This method is called once before training begins and is typically used + to initialize datasets, sampling conditions, or internal state. - :param ~lightning.pytorch.trainer.trainer.Trainer trainer: The trainer - object. - :param ~pina.solver.solver.SolverInterface solver: The solver - object associated with the trainer. - :raises RuntimeError: If the solver is not a PINNInterface. - :raises RuntimeError: If the conditions do not have a domain to sample - from. + :param Trainer trainer: The trainer managing the training loop. + :param SolverInterface solver: The solver associated with the trainer. """ - # check we have valid conditions names - if self._condition_to_update is None: - self._condition_to_update = [ - name - for name, cond in solver.problem.conditions.items() - if hasattr(cond, "domain") - ] - - for cond in self._condition_to_update: - if cond not in solver.problem.conditions: - raise RuntimeError( - f"Condition '{cond}' not found in " - f"{list(solver.problem.conditions.keys())}." - ) - if not hasattr(solver.problem.conditions[cond], "domain"): - raise RuntimeError( - f"Condition '{cond}' does not contain a domain to " - "sample from." - ) - # check solver - if not isinstance(solver, PINNInterface): - raise RuntimeError( - "Refinment strategies are currently implemented only " - "for physics informed based solvers. Please use a Solver " - "inheriting from 'PINNInterface'." - ) - # store dataset - self._dataset = trainer.datamodule.train_dataset - # compute initial population size - self._initial_population_size = self._compute_population_size( - self._condition_to_update - ) - return super().on_train_epoch_start(trainer, solver) + @abstractmethod def on_train_epoch_end(self, trainer, solver): """ - Performs the refinement at the end of each training epoch (if needed). + Apply refinement at the end of a training epoch. + + This method is invoked after each epoch and can update the dataset based + on the current state of the model. 
- :param ~lightning.pytorch.trainer.trainer.Trainer: The trainer object. - :param PINNInterface solver: The solver object. + :param Trainer trainer: The trainer managing the training loop. + :param SolverInterface solver: The solver associated with the trainer. """ - if (trainer.current_epoch % self.sample_every == 0) and ( - trainer.current_epoch != 0 - ): - self._update_points(solver) - return super().on_train_epoch_end(trainer, solver) @abstractmethod def sample(self, current_points, condition_name, solver): """ - Samples new points based on the condition. + Generate new sample points for a given condition. - :param current_points: Current points in the domain. - :param condition_name: Name of the condition to update. - :param PINNInterface solver: The solver object. - :return: New points sampled based on the R3 strategy. + :param LabelTensor current_points: The existing points in the domain. + :param str condition_name: The identifier of the condition to refine. + :param SolverInterface solver: The solver used for sampling decisions. + :return: Newly sampled points. :rtype: LabelTensor """ @property + @abstractmethod def dataset(self): """ - Returns the dataset for training. - """ - return self._dataset + The training datasets managed by the refinement strategy. - @property - def initial_population_size(self): - """ - Returns the dataset for training size. - """ - return self._initial_population_size - - def _update_points(self, solver): - """ - Performs the refinement of the points. + The dataset is stored as a dictionary whose keys are condition names and + whose values are the corresponding dataset subsets. The content of this + dictionary can be updated dynamically during refinement. - :param PINNInterface solver: The solver object. + :return: The mapping between condition names and dataset subsets. + :rtype: dict """ - new_points = {} - for name in self._condition_to_update: - current_points = self.dataset.conditions_dict[name]["input"] - new_points[name] = { - "input": self.sample(current_points, name, solver) - } - self.dataset.update_data(new_points) - def _compute_population_size(self, conditions): + @property + @abstractmethod + def initial_population_size(self): """ - Computes the number of points in the dataset for each condition. + Initial size of the sampled dataset for each condition before any + refinement is applied. - :param conditions: List of conditions to compute the number of points. - :return: Dictionary with the population size for each condition. - :rtype: dict + :return: A mapping between each condition name and its initial number + of sampled points. + :rtype: dict[str, int] """ - return { - cond: len(self.dataset.conditions_dict[cond]["input"]) - for cond in conditions - } diff --git a/pina/_src/condition/base_condition.py b/pina/_src/condition/base_condition.py index 013c5bf24..939c75e39 100644 --- a/pina/_src/condition/base_condition.py +++ b/pina/_src/condition/base_condition.py @@ -67,7 +67,7 @@ def create_dataloader( """ Create the DataLoader for the condition. - :param Dataset dataset: The dataset for the DataLoader. + :param _ConditionSubset dataset: The dataset for the DataLoader. :param int batch_size: The batch size for the DataLoader. :param bool automatic_batching: Whether to use automatic batching. :param dict kwargs: Additional keyword arguments for the DataLoader. 
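Taken together, the callbacks reworked in this diff compose into a staged
training setup. A minimal sketch, assuming a ``solver`` built elsewhere and
the constructor signatures documented above (the ``TorchScheduler`` arguments
are illustrative, not taken from this diff):

    >>> optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01)
    >>> scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
    >>> trainer = Trainer(
    ...     solver,
    ...     callbacks=[
    ...         DataNormalizer(stage="all", apply_to="input"),
    ...         SwitchOptimizer(new_optimizers=optimizer, epoch_switch=10),
    ...         SwitchScheduler(new_schedulers=scheduler, epoch_switch=10),
    ...         MetricTracker(),
    ...     ],
    ... )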
diff --git a/pina/_src/condition/condition_interface.py b/pina/_src/condition/condition_interface.py index 9183d196f..bfeee7685 100644 --- a/pina/_src/condition/condition_interface.py +++ b/pina/_src/condition/condition_interface.py @@ -48,7 +48,7 @@ def create_dataloader( """ Create the DataLoader for the condition. - :param Dataset dataset: The dataset for the DataLoader. + :param _ConditionSubset dataset: The dataset for the DataLoader. :param int batch_size: The batch size for the DataLoader. :param bool automatic_batching: Whether to use automatic batching. :param dict kwargs: Additional keyword arguments for the DataLoader. diff --git a/pina/_src/data/data_module.py b/pina/_src/data/data_module.py index 4c7ab70c4..4a5b2c66a 100644 --- a/pina/_src/data/data_module.py +++ b/pina/_src/data/data_module.py @@ -156,6 +156,8 @@ def __init__( self.problem.move_discretisation_into_conditions() self._check_slit_sizes(train_size, test_size, val_size) + # TODO: singular forms (train_dataset, val_dataset, test_dataset) seem + # to be unused. Clean code. if train_size > 0: self.train_dataset = None else: diff --git a/pina/_src/data/dataset.py b/pina/_src/data/dataset.py index bf2f168e4..dcad84662 100644 --- a/pina/_src/data/dataset.py +++ b/pina/_src/data/dataset.py @@ -5,6 +5,8 @@ from torch_geometric.data import Data from pina._src.core.graph import Graph, LabelBatch +# TODO: the whole file seems to be unused, check if it can be safely deleted. + class PinaDatasetFactory: """ diff --git a/pina/_src/optim/optimizer_interface.py b/pina/_src/optim/optimizer_interface.py index 5f2fbe66a..b60e23624 100644 --- a/pina/_src/optim/optimizer_interface.py +++ b/pina/_src/optim/optimizer_interface.py @@ -1,23 +1,30 @@ -"""Module for the PINA Optimizer.""" +"""Module for the Optimizer Interface.""" from abc import ABCMeta, abstractmethod -class Optimizer(metaclass=ABCMeta): +class OptimizerInterface(metaclass=ABCMeta): """ - Abstract base class for defining an optimizer. All specific optimizers - should inherit form this class and implement the required methods. + Abstract interface for all optimizers. """ - @property @abstractmethod - def instance(self): + def hook(self, parameters): """ - Abstract property to retrieve the optimizer instance. + Execute custom logic associated with the optimizer instance. + + This method is intended to encapsulate any additional behavior that + should be triggered during the optimization process. + + :param dict parameters: The parameters of the model to be optimized. """ + @property @abstractmethod - def hook(self): + def instance(self): """ - Abstract method to define the hook logic for the optimizer. + The underlying optimizer object. + + :return: The optimizer instance. + :rtype: object """ diff --git a/pina/_src/optim/scheduler_interface.py b/pina/_src/optim/scheduler_interface.py index 5ae5d8b99..55951ee0e 100644 --- a/pina/_src/optim/scheduler_interface.py +++ b/pina/_src/optim/scheduler_interface.py @@ -1,23 +1,31 @@ -"""Module for the PINA Scheduler.""" +"""Module for the Scheduler Interface.""" from abc import ABCMeta, abstractmethod -class Scheduler(metaclass=ABCMeta): +class SchedulerInterface(metaclass=ABCMeta): """ - Abstract base class for defining a scheduler. All specific schedulers should - inherit form this class and implement the required methods. + Abstract interface for all schedulers. """ - @property @abstractmethod - def instance(self): + def hook(self, optimizer): """ - Abstract property to retrieve the scheduler instance. 
+ Execute custom logic associated with the scheduler instance. + + This method is intended to encapsulate any additional behavior that + should be triggered during the optimization process. + + :param OptimizerInterface optimizer: The optimizer instance associated + with the scheduler. """ + @property @abstractmethod - def hook(self): + def instance(self): """ - Abstract method to define the hook logic for the scheduler. + The underlying scheduler object. + + :return: The scheduler instance. + :rtype: object """ diff --git a/pina/_src/optim/torch_optimizer.py b/pina/_src/optim/torch_optimizer.py index f01d3b3cb..a37bfbfec 100644 --- a/pina/_src/optim/torch_optimizer.py +++ b/pina/_src/optim/torch_optimizer.py @@ -1,35 +1,46 @@ -"""Module for the PINA Torch Optimizer""" +"""Module for wrapping PyTorch optimizers.""" import torch - from pina._src.core.utils import check_consistency -from pina._src.optim.optimizer_interface import Optimizer +from pina._src.optim.optimizer_interface import OptimizerInterface -class TorchOptimizer(Optimizer): +class TorchOptimizer(OptimizerInterface): """ - A wrapper class for using PyTorch optimizers. + The wrapper class for PyTorch optimizers. + + This class wraps a ``torch.optim.Optimizer`` class and defers its + instantiation until runtime. It enables a consistent interface across + different optimizer backends while leveraging PyTorch’s optimization + algorithms. """ def __init__(self, optimizer_class, **kwargs): """ Initialization of the :class:`TorchOptimizer` class. - :param torch.optim.Optimizer optimizer_class: A - :class:`torch.optim.Optimizer` class. - :param dict kwargs: Additional parameters passed to ``optimizer_class``, - see more + :param torch.optim.Optimizer optimizer_class: The subclass of + ``torch.optim.Optimizer`` to be instantiated. + :param dict kwargs: Additional keyword arguments forwarded to the + optimizer constructor. See more `here `_. + :raises ValueError: If ``optimizer_class`` is not a subclass of + ``torch.optim.Optimizer``. """ + # Check consistency check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True) + # Initialize attributes self.optimizer_class = optimizer_class self.kwargs = kwargs self._optimizer_instance = None def hook(self, parameters): """ - Initialize the optimizer instance with the given parameters. + Execute custom logic associated with the optimizer instance. + + This method is intended to encapsulate any additional behavior that + should be triggered during the optimization process. :param dict parameters: The parameters of the model to be optimized. """ @@ -40,7 +51,7 @@ def hook(self, parameters): @property def instance(self): """ - Get the optimizer instance. + The underlying optimizer object. :return: The optimizer instance. 
:rtype: torch.optim.Optimizer diff --git a/pina/_src/optim/torch_scheduler.py b/pina/_src/optim/torch_scheduler.py index bf9927836..f33b6020f 100644 --- a/pina/_src/optim/torch_scheduler.py +++ b/pina/_src/optim/torch_scheduler.py @@ -1,34 +1,35 @@ -"""Module for the PINA Torch Optimizer""" - -try: - from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 -except ImportError: - from torch.optim.lr_scheduler import ( - _LRScheduler as LRScheduler, - ) # torch < 2.0 +"""Module for wrapping PyTorch schedulers.""" +from torch.optim.lr_scheduler import LRScheduler from pina._src.core.utils import check_consistency -from pina._src.optim.optimizer_interface import Optimizer -from pina._src.optim.scheduler_interface import Scheduler +from pina._src.optim.optimizer_interface import OptimizerInterface +from pina._src.optim.scheduler_interface import SchedulerInterface -class TorchScheduler(Scheduler): +class TorchScheduler(SchedulerInterface): """ - A wrapper class for using PyTorch schedulers. + The wrapper class for PyTorch schedulers. + + This class wraps a ``torch.optim.lr_scheduler.LRScheduler`` class and defers + its instantiation until runtime, once the optimizer instance is available. """ def __init__(self, scheduler_class, **kwargs): """ Initialization of the :class:`TorchScheduler` class. - :param torch.optim.LRScheduler scheduler_class: A - :class:`torch.optim.LRScheduler` class. - :param dict kwargs: Additional parameters passed to ``scheduler_class``, - see more - `here _`. + :param torch.optim.LRScheduler scheduler_class: The subclass of + ``torch.optim.lr_scheduler.LRScheduler`` to be instantiated. + :param dict kwargs: Additional keyword arguments forwarded to the + scheduler constructor. See more + `here `_. + :raises ValueError: If ``scheduler_class`` is not a subclass of + ``torch.optim.lr_scheduler.LRScheduler``. """ + # Check consistency check_consistency(scheduler_class, LRScheduler, subclass=True) + # Initialize attributes self.scheduler_class = scheduler_class self.kwargs = kwargs self._scheduler_instance = None @@ -37,9 +38,15 @@ def hook(self, optimizer): """ Initialize the scheduler instance with the given parameters. - :param dict parameters: The parameters of the optimizer. + :param OptimizerInterface optimizer: The optimizer instance associated + with the scheduler. + :raises ValueError: If ``optimizer`` is not an instance of + :class:`OptimizerInterface`. """ - check_consistency(optimizer, Optimizer) + # Check consistency + check_consistency(optimizer, OptimizerInterface) + + # Initialize the scheduler instance self._scheduler_instance = self.scheduler_class( optimizer.instance, **self.kwargs ) @@ -47,9 +54,9 @@ def hook(self, optimizer): @property def instance(self): """ - Get the scheduler instance. + The underlying scheduler object. - :return: The scheduelr instance. - :rtype: torch.optim.LRScheduler + :return: The scheduler instance. + :rtype: torch.optim.lr_scheduler.LRScheduler """ return self._scheduler_instance diff --git a/pina/_src/solver/autoregressive_solver/autoregressive_solver.py b/pina/_src/solver/autoregressive_solver/autoregressive_solver.py index 58bf8bdca..31133018a 100644 --- a/pina/_src/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/_src/solver/autoregressive_solver/autoregressive_solver.py @@ -53,10 +53,10 @@ def __init__( :param torch.nn.Module loss: The loss function to be minimized. If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is ``None``. - :param Optimizer optimizer: The optimizer to be used. 
+ :param OptimizerInterface optimizer: The optimizer to be used. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param Scheduler scheduler: Learning rate scheduler. + :param SchedulerInterface scheduler: Learning rate scheduler. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. diff --git a/pina/_src/solver/ensemble_solver/ensemble_pinn.py b/pina/_src/solver/ensemble_solver/ensemble_pinn.py index af117d702..743b3db09 100644 --- a/pina/_src/solver/ensemble_solver/ensemble_pinn.py +++ b/pina/_src/solver/ensemble_solver/ensemble_pinn.py @@ -92,10 +92,10 @@ def __init__( :param torch.nn.Module loss: The loss function to be minimized. If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is ``None``. - :param Optimizer optimizer: The optimizer to be used. + :param OptimizerInterface optimizers: The optimizers to be used. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param Scheduler scheduler: Learning rate scheduler. + :param SchedulerInterface schedulers: Learning rate schedulers. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. diff --git a/pina/_src/solver/ensemble_solver/ensemble_solver_interface.py b/pina/_src/solver/ensemble_solver/ensemble_solver_interface.py index ed0fc2d29..0134e3a98 100644 --- a/pina/_src/solver/ensemble_solver/ensemble_solver_interface.py +++ b/pina/_src/solver/ensemble_solver/ensemble_solver_interface.py @@ -61,10 +61,10 @@ def __init__( :param BaseProblem problem: The problem to be solved. :param torch.nn.Module models: The neural network models to be used. - :param Optimizer optimizer: The optimizer to be used. + :param OptimizerInterface optimizers: The optimizers to be used. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param Scheduler scheduler: Learning rate scheduler. + :param SchedulerInterface schedulers: Learning rate schedulers. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. diff --git a/pina/_src/solver/ensemble_solver/ensemble_supervised.py b/pina/_src/solver/ensemble_solver/ensemble_supervised.py index e98ab7ed1..f2e26a5f2 100644 --- a/pina/_src/solver/ensemble_solver/ensemble_supervised.py +++ b/pina/_src/solver/ensemble_solver/ensemble_supervised.py @@ -81,10 +81,10 @@ def __init__( :param torch.nn.Module loss: The loss function to be minimized. If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is ``None``. - :param Optimizer optimizer: The optimizer to be used. + :param OptimizerInterface optimizers: The optimizers to be used. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param Scheduler scheduler: Learning rate scheduler. + :param SchedulerInterface schedulers: Learning rate schedulers. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. 
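# A quick sketch (outside the patch) of the deferred-instantiation contract
# documented above: both wrappers are inert containers until `hook` is called,
# which is when the underlying torch objects are actually built. The linear
# model is a stand-in for any `torch.nn.Module`.
import torch
from pina.optim import TorchOptimizer, TorchScheduler

model = torch.nn.Linear(2, 1)

optimizer = TorchOptimizer(torch.optim.Adam, lr=1e-3)
assert optimizer.instance is None  # nothing has been built yet

optimizer.hook(model.parameters())  # builds the torch.optim.Adam instance
assert isinstance(optimizer.instance, torch.optim.Adam)

scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR, factor=1.0)
scheduler.hook(optimizer)  # requires an already-hooked OptimizerInterface
assert isinstance(scheduler.instance, torch.optim.lr_scheduler.ConstantLR)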
diff --git a/pina/_src/solver/garom.py b/pina/_src/solver/garom.py index 29b1c67ac..d476c2d3b 100644 --- a/pina/_src/solver/garom.py +++ b/pina/_src/solver/garom.py @@ -48,18 +48,18 @@ def __init__( :param torch.nn.Module loss: The loss function to be minimized. If ``None``, :class:`~pina.loss.power_loss.PowerLoss` with ``p=1`` is used. Default is ``None``. - :param Optimizer optimizer_generator: The optimizer for the generator. - If ``None``, the :class:`torch.optim.Adam` optimizer is used. - Default is ``None``. - :param Optimizer optimizer_discriminator: The optimizer for the + :param OptimizerInterface optimizer_generator: The optimizer for the + generator. If ``None``, the :class:`torch.optim.Adam` optimizer is + used. Default is ``None``. + :param OptimizerInterface optimizer_discriminator: The optimizer for the discriminator. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param Scheduler scheduler_generator: The learning rate scheduler for - the generator. + :param SchedulerInterface scheduler_generator: The learning rate + scheduler for the generator. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. - :param Scheduler scheduler_discriminator: The learning rate scheduler - for the discriminator. + :param SchedulerInterface scheduler_discriminator: The learning rate + scheduler for the discriminator. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param float gamma: Ratio of expected loss for generator and @@ -328,7 +328,7 @@ def optimizer_generator(self): The optimizer for the generator. :return: The optimizer for the generator. - :rtype: Optimizer + :rtype: OptimizerInterface """ return self.optimizers[0] @@ -338,7 +338,7 @@ def optimizer_discriminator(self): The optimizer for the discriminator. :return: The optimizer for the discriminator. - :rtype: Optimizer + :rtype: OptimizerInterface """ return self.optimizers[1] @@ -348,7 +348,7 @@ def scheduler_generator(self): The scheduler for the generator. :return: The scheduler for the generator. - :rtype: Scheduler + :rtype: SchedulerInterface """ return self.schedulers[0] @@ -358,6 +358,6 @@ def scheduler_discriminator(self): The scheduler for the discriminator. :return: The scheduler for the discriminator. - :rtype: Scheduler + :rtype: SchedulerInterface """ return self.schedulers[1] diff --git a/pina/_src/solver/physics_informed_solver/causal_pinn.py b/pina/_src/solver/physics_informed_solver/causal_pinn.py index cfcbbea20..c061b783f 100644 --- a/pina/_src/solver/physics_informed_solver/causal_pinn.py +++ b/pina/_src/solver/physics_informed_solver/causal_pinn.py @@ -82,10 +82,10 @@ def __init__( inherit from at least :class:`~pina.problem.time_dependent_problem.TimeDependentProblem`. :param torch.nn.Module model: The neural network model to be used. - :param Optimizer optimizer: The optimizer to be used. + :param OptimizerInterface optimizer: The optimizer to be used. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param torch.optim.LRScheduler scheduler: Learning rate scheduler. + :param SchedulerInterface scheduler: Learning rate scheduler. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. 
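# A usage sketch (outside the patch) showing the renamed wrappers plugged into
# a single-model solver; the problem, network size, and hyperparameters are
# illustrative only, and the documented defaults (Adam, ConstantLR) apply
# whenever the arguments are omitted.
import torch
from pina.optim import TorchOptimizer, TorchScheduler
from pina.solver import PINN
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem

problem = Poisson2DSquareProblem()
problem.discretise_domain(10)

model = FeedForward(
    len(problem.input_variables), len(problem.output_variables)
)

# Both arguments expect the interface types documented above.
solver = PINN(
    problem=problem,
    model=model,
    optimizer=TorchOptimizer(torch.optim.AdamW, lr=1e-3),
    scheduler=TorchScheduler(torch.optim.lr_scheduler.StepLR, step_size=10),
)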
diff --git a/pina/_src/solver/physics_informed_solver/competitive_pinn.py b/pina/_src/solver/physics_informed_solver/competitive_pinn.py index 42096fa64..1b946e26f 100644 --- a/pina/_src/solver/physics_informed_solver/competitive_pinn.py +++ b/pina/_src/solver/physics_informed_solver/competitive_pinn.py @@ -73,18 +73,18 @@ def __init__( :param torch.nn.Module discriminator: The discriminator to be used. If ``None``, the discriminator is a deepcopy of the ``model``. Default is ``None``. - :param torch.optim.Optimizer optimizer_model: The optimizer of the + :param OptimizerInterface optimizer_model: The optimizer of the ``model``. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param torch.optim.Optimizer optimizer_discriminator: The optimizer of + :param OptimizerInterface optimizer_discriminator: The optimizer of the ``discriminator``. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param Scheduler scheduler_model: Learning rate scheduler for the - ``model``. + :param SchedulerInterface scheduler_model: Learning rate scheduler for + the ``model``. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. - :param Scheduler scheduler_discriminator: Learning rate scheduler for - the ``discriminator``. + :param SchedulerInterface scheduler_discriminator: Learning rate + scheduler for the ``discriminator``. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. @@ -184,7 +184,7 @@ def configure_optimizers(self): Optimizer configuration. :return: The optimizers and the schedulers - :rtype: tuple[list[Optimizer], list[Scheduler]] + :rtype: tuple[list[OptimizerInterface], list[SchedulerInterface]] """ # If the problem is an InverseProblem, add the unknown parameters # to the parameters to be optimized @@ -238,7 +238,7 @@ def optimizer_model(self): The optimizer associated to the model. :return: The optimizer for the model. - :rtype: Optimizer + :rtype: OptimizerInterface """ return self.optimizers[0] @@ -248,7 +248,7 @@ def optimizer_discriminator(self): The optimizer associated to the discriminator. :return: The optimizer for the discriminator. - :rtype: Optimizer + :rtype: OptimizerInterface """ return self.optimizers[1] @@ -258,7 +258,7 @@ def scheduler_model(self): The scheduler associated to the model. :return: The scheduler for the model. - :rtype: Scheduler + :rtype: SchedulerInterface """ return self.schedulers[0] @@ -268,6 +268,6 @@ def scheduler_discriminator(self): The scheduler associated to the discriminator. :return: The scheduler for the discriminator. - :rtype: Scheduler + :rtype: SchedulerInterface """ return self.schedulers[1] diff --git a/pina/_src/solver/physics_informed_solver/gradient_pinn.py b/pina/_src/solver/physics_informed_solver/gradient_pinn.py index 4ee2b3089..72798b10a 100644 --- a/pina/_src/solver/physics_informed_solver/gradient_pinn.py +++ b/pina/_src/solver/physics_informed_solver/gradient_pinn.py @@ -74,10 +74,10 @@ def __init__( :class:`~pina.problem.spatial_problem.SpatialProblem` to compute the gradient of the loss. :param torch.nn.Module model: The neural network model to be used. - :param Optimizer optimizer: The optimizer to be used. + :param OptimizerInterface optimizer: The optimizer to be used. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. 
- :param Scheduler scheduler: Learning rate scheduler. + :param SchedulerInterface scheduler: Learning rate scheduler. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. diff --git a/pina/_src/solver/physics_informed_solver/pinn.py b/pina/_src/solver/physics_informed_solver/pinn.py index 59b61214e..47ffa6d6d 100644 --- a/pina/_src/solver/physics_informed_solver/pinn.py +++ b/pina/_src/solver/physics_informed_solver/pinn.py @@ -63,10 +63,10 @@ def __init__( :param BaseProblem problem: The problem to be solved. :param torch.nn.Module model: The neural network model to be used. - :param Optimizer optimizer: The optimizer to be used. + :param OptimizerInterface optimizer: The optimizer to be used. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param Scheduler scheduler: Learning rate scheduler. + :param SchedulerInterface scheduler: Learning rate scheduler. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. @@ -117,7 +117,7 @@ def configure_optimizers(self): Optimizer configuration for the PINN solver. :return: The optimizers and the schedulers - :rtype: tuple[list[Optimizer], list[Scheduler]] + :rtype: tuple[list[OptimizerInterface], list[SchedulerInterface]] """ # If the problem is an InverseProblem, add the unknown parameters # to the parameters to be optimized. diff --git a/pina/_src/solver/physics_informed_solver/rba_pinn.py b/pina/_src/solver/physics_informed_solver/rba_pinn.py index 5c7821120..e1d754f88 100644 --- a/pina/_src/solver/physics_informed_solver/rba_pinn.py +++ b/pina/_src/solver/physics_informed_solver/rba_pinn.py @@ -81,10 +81,10 @@ def __init__( :param BaseProblem problem: The problem to be solved. :param torch.nn.Module model: The neural network model to be used. - :param Optimizer optimizer: The optimizer to be used. + :param OptimizerInterface optimizer: The optimizer to be used. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param Scheduler scheduler: Learning rate scheduler. + :param SchedulerInterface scheduler: Learning rate scheduler. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. diff --git a/pina/_src/solver/physics_informed_solver/self_adaptive_pinn.py b/pina/_src/solver/physics_informed_solver/self_adaptive_pinn.py index 983eb2966..c8217a892 100644 --- a/pina/_src/solver/physics_informed_solver/self_adaptive_pinn.py +++ b/pina/_src/solver/physics_informed_solver/self_adaptive_pinn.py @@ -125,19 +125,19 @@ def __init__( :param torch.nn.Module model: The model to be used. :param torch.nn.Module weight_function: The Self-Adaptive mask model. Default is ``torch.nn.Sigmoid()``. - :param Optimizer optimizer_model: The optimizer of the ``model``. - If ``None``, the :class:`torch.optim.Adam` optimizer is used. - Default is ``None``. - :param Optimizer optimizer_weights: The optimizer of the + :param OptimizerInterface optimizer_model: The optimizer of the + ``model``. If ``None``, the :class:`torch.optim.Adam` optimizer is + used. Default is ``None``. + :param OptimizerInterface optimizer_weights: The optimizer of the ``weight_function``. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. 
- :param Scheduler scheduler_model: Learning rate scheduler for the - ``model``. + :param SchedulerInterface scheduler_model: Learning rate scheduler for + the ``model``. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. - :param Scheduler scheduler_weights: Learning rate scheduler for the - ``weight_function``. + :param SchedulerInterface scheduler_weights: Learning rate scheduler for + the ``weight_function``. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. @@ -296,7 +296,7 @@ def configure_optimizers(self): Optimizer configuration. :return: The optimizers and the schedulers - :rtype: tuple[list[Optimizer], list[Scheduler]] + :rtype: tuple[list[OptimizerInterface], list[SchedulerInterface]] """ # Hook the optimizers to the models self.optimizer_model.hook(self.model.parameters()) @@ -421,7 +421,7 @@ def scheduler_model(self): The scheduler associated to the model. :return: The scheduler for the model. - :rtype: Scheduler + :rtype: SchedulerInterface """ return self.schedulers[0] @@ -431,7 +431,7 @@ def scheduler_weights(self): The scheduler associated to the mask model. :return: The scheduler for the mask model. - :rtype: Scheduler + :rtype: SchedulerInterface """ return self.schedulers[1] @@ -441,7 +441,7 @@ def optimizer_model(self): Returns the optimizer associated to the model. :return: The optimizer for the model. - :rtype: Optimizer + :rtype: OptimizerInterface """ return self.optimizers[0] @@ -451,6 +451,6 @@ def optimizer_weights(self): The optimizer associated to the mask model. :return: The optimizer for the mask model. - :rtype: Optimizer + :rtype: OptimizerInterface """ return self.optimizers[1] diff --git a/pina/_src/solver/solver.py b/pina/_src/solver/solver.py index 571892f05..3d1f8de36 100644 --- a/pina/_src/solver/solver.py +++ b/pina/_src/solver/solver.py @@ -7,8 +7,8 @@ from torch._dynamo import OptimizedModule from pina._src.problem.base_problem import BaseProblem from pina._src.problem.inverse_problem import InverseProblem -from pina._src.optim.optimizer_interface import Optimizer -from pina._src.optim.scheduler_interface import Scheduler +from pina._src.optim.optimizer_interface import OptimizerInterface +from pina._src.optim.scheduler_interface import SchedulerInterface from pina._src.optim.torch_optimizer import TorchOptimizer from pina._src.optim.torch_scheduler import TorchScheduler from pina._src.weighting.weighting_interface import WeightingInterface @@ -316,7 +316,7 @@ def default_torch_optimizer(): Set the default optimizer to :class:`torch.optim.Adam`. :return: The default optimizer. - :rtype: Optimizer + :rtype: OptimizerInterface """ return TorchOptimizer(torch.optim.Adam, lr=0.001) @@ -327,7 +327,7 @@ def default_torch_scheduler(): :class:`torch.optim.lr_scheduler.ConstantLR`. :return: The default scheduler. - :rtype: Scheduler + :rtype: SchedulerInterface """ return TorchScheduler(torch.optim.lr_scheduler.ConstantLR, factor=1.0) @@ -381,10 +381,10 @@ def __init__( :param BaseProblem problem: The problem to be solved. :param torch.nn.Module model: The neural network model to be used. - :param Optimizer optimizer: The optimizer to be used. + :param OptimizerInterface optimizer: The optimizer to be used. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param Scheduler scheduler: The scheduler to be used. 
+ :param SchedulerInterface scheduler: The scheduler to be used. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. @@ -402,9 +402,9 @@ def __init__( # check consistency of models argument and encapsulate in list check_consistency(model, torch.nn.Module) # check scheduler consistency and encapsulate in list - check_consistency(scheduler, Scheduler) + check_consistency(scheduler, SchedulerInterface) # check optimizer consistency and encapsulate in list - check_consistency(optimizer, Optimizer) + check_consistency(optimizer, OptimizerInterface) # initialize the model (needed by Lightining to go to different devices) self._pina_models = torch.nn.ModuleList([model]) @@ -427,7 +427,7 @@ def configure_optimizers(self): Optimizer configuration for the solver. :return: The optimizer and the scheduler - :rtype: tuple[list[Optimizer], list[Scheduler]] + :rtype: tuple[list[OptimizerInterface], list[SchedulerInterface]] """ self.optimizer.hook(self.model.parameters()) if isinstance(self.problem, InverseProblem): @@ -458,7 +458,7 @@ def scheduler(self): The scheduler used for training. :return: The scheduler used for training. - :rtype: Scheduler + :rtype: SchedulerInterface """ return self._pina_schedulers[0] @@ -468,7 +468,7 @@ def optimizer(self): The optimizer used for training. :return: The optimizer used for training. - :rtype: Optimizer + :rtype: OptimizerInterface """ return self._pina_optimizers[0] @@ -493,10 +493,10 @@ def __init__( :param BaseProblem problem: The problem to be solved. :param models: The neural network models to be used. :type model: list[torch.nn.Module] | tuple[torch.nn.Module] - :param list[Optimizer] optimizers: The optimizers to be used. + :param list[OptimizerInterface] optimizers: The optimizers to be used. If ``None``, the :class:`torch.optim.Adam` optimizer is used for all models. Default is ``None``. - :param list[Scheduler] schedulers: The schedulers to be used. + :param list[SchedulerInterface] schedulers: The schedulers to be used. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used for all the models. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. @@ -548,10 +548,10 @@ def __init__( check_consistency(models, torch.nn.Module) # check scheduler consistency and encapsulate in list - check_consistency(schedulers, Scheduler) + check_consistency(schedulers, SchedulerInterface) # check optimizer consistency and encapsulate in list - check_consistency(optimizers, Optimizer) + check_consistency(optimizers, OptimizerInterface) # check length consistency optimizers if len(models) != len(optimizers): @@ -598,7 +598,7 @@ def configure_optimizers(self): Optimizer configuration for the solver. :return: The optimizer and the scheduler - :rtype: tuple[list[Optimizer], list[Scheduler]] + :rtype: tuple[list[OptimizerInterface], list[SchedulerInterface]] """ for optimizer, scheduler, model in zip( self.optimizers, self.schedulers, self.models @@ -627,7 +627,7 @@ def optimizers(self): The optimizers used for training. :return: The optimizers used for training. - :rtype: list[Optimizer] + :rtype: list[OptimizerInterface] """ return self._pina_optimizers @@ -637,6 +637,6 @@ def schedulers(self): The schedulers used for training. :return: The schedulers used for training. 
- :rtype: list[Scheduler] + :rtype: list[SchedulerInterface] """ return self._pina_schedulers diff --git a/pina/_src/solver/supervised_solver/reduced_order_model.py b/pina/_src/solver/supervised_solver/reduced_order_model.py index 3687a3e2b..585d0ef90 100644 --- a/pina/_src/solver/supervised_solver/reduced_order_model.py +++ b/pina/_src/solver/supervised_solver/reduced_order_model.py @@ -106,10 +106,10 @@ def __init__( :param torch.nn.Module loss: The loss function to be minimized. If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is `None`. - :param Optimizer optimizer: The optimizer to be used. + :param OptimizerInterface optimizer: The optimizer to be used. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param Scheduler scheduler: Learning rate scheduler. + :param SchedulerInterface scheduler: Learning rate scheduler. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. diff --git a/pina/_src/solver/supervised_solver/supervised.py b/pina/_src/solver/supervised_solver/supervised.py index cdbddffca..e7ee6d6e6 100644 --- a/pina/_src/solver/supervised_solver/supervised.py +++ b/pina/_src/solver/supervised_solver/supervised.py @@ -50,10 +50,10 @@ def __init__( :param torch.nn.Module loss: The loss function to be minimized. If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is `None`. - :param Optimizer optimizer: The optimizer to be used. + :param OptimizerInterface optimizer: The optimizer to be used. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. - :param Scheduler scheduler: Learning rate scheduler. + :param SchedulerInterface scheduler: Learning rate scheduler. If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. 
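# A data-driven counterpart (outside the patch) to the docstrings above: the
# same optimizer wrapper works with supervised solvers. The toy problem
# mirrors the test fixtures that follow and is illustrative only.
import torch
from pina import Trainer, Condition
from pina.problem import BaseProblem
from pina.model import FeedForward
from pina.solver import SupervisedSolver
from pina.optim import TorchOptimizer


class TinyProblem(BaseProblem):
    input_variables = ["x", "y"]
    output_variables = ["u"]
    conditions = {
        "data": Condition(input=torch.rand(20, 2), target=torch.rand(20, 1))
    }


solver = SupervisedSolver(
    problem=TinyProblem(),
    model=FeedForward(2, 1),
    optimizer=TorchOptimizer(torch.optim.Adam, lr=1e-3),
    use_lt=False,
)
trainer = Trainer(solver=solver, accelerator="cpu", max_epochs=1)
trainer.train()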
diff --git a/pina/callback/__init__.py b/pina/callback/__init__.py
index 2f6d5a0a2..a6b2e3973 100644
--- a/pina/callback/__init__.py
+++ b/pina/callback/__init__.py
@@ -8,17 +8,39 @@
 __all__ = [
     "SwitchOptimizer",
     "SwitchScheduler",
-    "NormalizerDataCallback",
+    "DataNormalizer",
     "PINAProgressBar",
     "MetricTracker",
+    "RefinementInterface",
+    "BaseRefinement",
     "R3Refinement",
 ]
 
-from pina._src.callback.optim.switch_optimizer import SwitchOptimizer
-from pina._src.callback.optim.switch_scheduler import SwitchScheduler
-from pina._src.callback.processing.normalizer_data_callback import (
-    NormalizerDataCallback,
-)
 from pina._src.callback.processing.pina_progress_bar import PINAProgressBar
 from pina._src.callback.processing.metric_tracker import MetricTracker
+from pina._src.callback.processing.data_normalizer import DataNormalizer
+from pina._src.callback.optim.switch_optimizer import SwitchOptimizer
+from pina._src.callback.optim.switch_scheduler import SwitchScheduler
+from pina._src.callback.refinement.base_refinement import BaseRefinement
 from pina._src.callback.refinement.r3_refinement import R3Refinement
+from pina._src.callback.refinement.refinement_interface import (
+    RefinementInterface,
+)
+
+# Backward compatibility with version 0.2, to be removed soon
+import warnings
+
+_DEPRECATED_IMPORTS = {"NormalizerDataCallback": "DataNormalizer"}
+
+
+def __getattr__(name):
+    if name in _DEPRECATED_IMPORTS:
+        warnings.warn(
+            f"Importing '{name}' from 'pina.callback' is deprecated; use "
+            f"'pina.callback.{_DEPRECATED_IMPORTS[name]}' instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return globals()[_DEPRECATED_IMPORTS[name]]
+
+    # Preserve the normal failure mode for unknown attributes
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/pina/optim/__init__.py b/pina/optim/__init__.py
index 682b6225e..f88b85e7a 100644
--- a/pina/optim/__init__.py
+++ b/pina/optim/__init__.py
@@ -1,13 +1,34 @@
 """Module for the Optimizers and Schedulers."""
 
 __all__ = [
-    "Optimizer",
+    "OptimizerInterface",
+    "SchedulerInterface",
     "TorchOptimizer",
-    "Scheduler",
     "TorchScheduler",
 ]
 
-from pina._src.optim.optimizer_interface import Optimizer
+from pina._src.optim.optimizer_interface import OptimizerInterface
+from pina._src.optim.scheduler_interface import SchedulerInterface
 from pina._src.optim.torch_optimizer import TorchOptimizer
-from pina._src.optim.scheduler_interface import Scheduler
 from pina._src.optim.torch_scheduler import TorchScheduler
+
+# Backward compatibility with version 0.2, to be removed soon
+import warnings
+
+_DEPRECATED_IMPORTS = {
+    "Optimizer": "OptimizerInterface",
+    "Scheduler": "SchedulerInterface",
+}
+
+
+def __getattr__(name):
+    if name in _DEPRECATED_IMPORTS:
+        warnings.warn(
+            f"Importing '{name}' from 'pina.optim' is deprecated; use "
+            f"'pina.optim.{_DEPRECATED_IMPORTS[name]}' instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return globals()[_DEPRECATED_IMPORTS[name]]
+
+    # Preserve the normal failure mode for unknown attributes
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/tests/test_callback/test_data_normalizer.py b/tests/test_callback/test_data_normalizer.py
new file mode 100644
index 000000000..ea28631c5
--- /dev/null
+++ b/tests/test_callback/test_data_normalizer.py
@@ -0,0 +1,192 @@
+import torch
+import pytest
+from pina import Trainer, LabelTensor, Condition
+from pina.solver import SupervisedSolver
+from pina.callback import DataNormalizer
+from pina.problem import BaseProblem
+from pina.model import FeedForward
+from pina.graph import RadiusGraph
+
+
+# Tensor-based problem
+class TensorProblem(BaseProblem):
+    input_variables = ["x", "y"]
+    output_variables = ["u"]
+    conditions = {
+        "data1": Condition(input=torch.rand(20, 2),
target=torch.rand(20, 1)), + "data2": Condition(input=torch.rand(20, 2), target=torch.rand(20, 1)), + } + + +# LabelTensor-based problem +class LabelTensorProblem(BaseProblem): + input_variables = ["x", "y"] + output_variables = ["u"] + conditions = { + "data1": Condition( + input=LabelTensor(torch.rand(20, 2), ["x", "y"]), + target=LabelTensor(torch.rand(20, 1), ["u"]), + ), + "data2": Condition( + input=LabelTensor(torch.rand(20, 2), ["x", "y"]), + target=LabelTensor(torch.rand(20, 1), ["u"]), + ), + } + + +# Graph-based problem for testing unsupported dataset case +input_graph = [RadiusGraph(radius=0.5, pos=torch.rand(10, 2)) for _ in range(5)] +target_tensor = torch.rand(5, 1) + + +class GraphProblem(BaseProblem): + + input_variables = ["x", "y"] + output_variables = ["u"] + conditions = {"data1": Condition(input=input_graph, target=target_tensor)} + + +# Mapping from stage to dataset names +stage_map = { + "train": ["train_datasets"], + "validate": ["val_datasets"], + "test": ["test_datasets"], + "all": ["train_datasets", "val_datasets", "test_datasets"], +} + + +@pytest.mark.parametrize("scale_fn", [torch.std, torch.var]) +@pytest.mark.parametrize("shift_fn", [torch.mean, torch.median]) +@pytest.mark.parametrize("apply_to", ["input", "target"]) +@pytest.mark.parametrize("stage", ["train", "validate", "test", "all"]) +def test_constructor(scale_fn, shift_fn, apply_to, stage): + DataNormalizer( + scale_fn=scale_fn, shift_fn=shift_fn, stage=stage, apply_to=apply_to + ) + + # Should fail if scale_fn is not Callable + with pytest.raises(ValueError): + DataNormalizer(scale_fn=1) + + # Should fail if shift_fn is not Callable + with pytest.raises(ValueError): + DataNormalizer(shift_fn=1) + + # Should fail if apply_to is invalid + with pytest.raises(ValueError): + DataNormalizer(apply_to="invalid") + + # Should fail if stage is invalid + with pytest.raises(ValueError): + DataNormalizer(stage="invalid") + + +@pytest.mark.parametrize("apply_to", ["input", "target"]) +@pytest.mark.parametrize("stage", ["train", "validate", "test", "all"]) +@pytest.mark.parametrize("scale_fn", [torch.std, torch.var]) +@pytest.mark.parametrize("shift_fn", [torch.mean, torch.median]) +@pytest.mark.parametrize("use_lt", [True, False]) +def test_routine(apply_to, stage, scale_fn, shift_fn, use_lt): + + # Initialize problem, model and solver + problem = LabelTensorProblem() if use_lt else TensorProblem() + model = FeedForward( + len(problem.input_variables), len(problem.output_variables) + ) + solver = SupervisedSolver(problem=problem, model=model, use_lt=use_lt) + + # Initialize the callback + callback = DataNormalizer( + scale_fn=scale_fn, + shift_fn=shift_fn, + stage=stage, + apply_to=apply_to, + ) + + # Initialize the trainer + trainer = Trainer( + solver=solver, + callbacks=callback, + accelerator="cpu", + max_epochs=3, + train_size=0.6, + val_size=0.2, + test_size=0.2, + ) + + # Run the training and testing routines + trainer.train() + trainer.test() + + # Store datasets to check normalization + datasets = { + "train_datasets": trainer.datamodule.train_datasets, + "val_datasets": trainer.datamodule.val_datasets, + "test_datasets": trainer.datamodule.test_datasets, + } + + # Save the expected normalized datasets for each stage + expected_normalized_datasets = stage_map[stage] + + # Check computed normalizer exists for all input-target conditions + for name in solver.problem.conditions.keys(): + assert name in callback.normalizer + assert "scale" in callback.normalizer[name] + assert "shift" in 
callback.normalizer[name] + + # Check normalized datasets + for ds_name, dataset in datasets.items(): + for c_name in callback.normalizer.keys(): + + # Extract the data and container for the current condition + points = getattr(dataset[c_name].condition, apply_to) + + # Check normalization parameters are correct for normalized datasets + if ds_name in expected_normalized_datasets: + expected_shift = shift_fn(points) + + # The expected shift should be close to zero after normalization + assert torch.isclose( + expected_shift, + torch.zeros_like(expected_shift), + atol=1e-5, + ) + + # The expected scale should be close to one after normalization + if scale_fn is torch.std: + expected_scale = scale_fn(points) + + assert torch.isclose( + expected_scale, + torch.ones_like(expected_scale), + atol=1e-5, + ) + + # Should fail if the dataset is graph-based and therefore unsupported + with pytest.raises(NotImplementedError): + + # Initialize problem, model and solver with graph-based problem + model = FeedForward( + len(GraphProblem.input_variables), + len(GraphProblem.output_variables), + ) + solver = SupervisedSolver(problem=GraphProblem(), model=model) + + # Initialize the callback + callback = DataNormalizer( + scale_fn=scale_fn, + shift_fn=shift_fn, + stage=stage, + apply_to=apply_to, + ) + + # Initialize the trainer + trainer = Trainer( + solver=solver, + callbacks=callback, + accelerator="cpu", + max_epochs=3, + ) + + # Run the training routine to trigger the error + trainer.train() diff --git a/tests/test_callback/test_metric_tracker.py b/tests/test_callback/test_metric_tracker.py index 49b904885..387a98ac1 100644 --- a/tests/test_callback/test_metric_tracker.py +++ b/tests/test_callback/test_metric_tracker.py @@ -1,38 +1,71 @@ +import torch +import pytest from pina.solver import PINN -from pina.trainer import Trainer from pina.model import FeedForward from pina.callback import MetricTracker -from pina.problem.zoo import Poisson2DSquareProblem as Poisson - -# make the problem -poisson_problem = Poisson() -n = 10 -poisson_problem.discretise_domain(n, "grid", domains="boundary") -poisson_problem.discretise_domain(n, "grid", domains="D") -model = FeedForward( - len(poisson_problem.input_variables), len(poisson_problem.output_variables) +from pina import Trainer, Condition, LabelTensor +from pina.problem.zoo import Poisson2DSquareProblem + +# Initialize the problem +problem = Poisson2DSquareProblem() +problem.discretise_domain(10, "random") +problem.conditions["data"] = Condition( + input=LabelTensor(torch.randn(10, 2), labels=["x", "y"]), + target=LabelTensor(torch.randn(10, 1), labels=["u"]), +) + +# Initialize the model and solver +model = FeedForward(len(problem.input_variables), len(problem.output_variables)) +solver = PINN(problem=problem, model=model) + + +@pytest.mark.parametrize( + "metrics_to_track", [["D_loss", "train_loss"], "data_loss", None] ) +def test_constructor(metrics_to_track): + MetricTracker(metrics_to_track=metrics_to_track) -# make the solver -solver = PINN(problem=poisson_problem, model=model) + # Should fail if metrics_to_track is not a string or list of strings + with pytest.raises(ValueError): + MetricTracker(metrics_to_track=123) -def test_metric_tracker_constructor(): - MetricTracker() +@pytest.mark.parametrize( + "metrics_to_track", [["D_loss", "train_loss"], "data_loss", None] +) +@pytest.mark.parametrize("batch_size", [None, 8]) +def test_routine(metrics_to_track, batch_size): + # Initialize the callback + callback = 
MetricTracker(metrics_to_track=metrics_to_track) -def test_metric_tracker_routine(): - # make the trainer + # Convert to list if a single string is provided + if isinstance(metrics_to_track, str): + metrics_to_track = [metrics_to_track] + + # Initialize the trainer with the callback and train the model trainer = Trainer( solver=solver, - callbacks=[MetricTracker()], + callbacks=[callback], accelerator="cpu", max_epochs=5, + batch_size=batch_size, log_every_n_steps=1, ) trainer.train() - # get the tracked metrics - metrics = trainer.callbacks[0].metrics - # assert the logged metrics are correct - logged_metrics = sorted(list(metrics.keys())) - assert logged_metrics == ["train_loss"] + + # Get the logged metrics from the callback + logged_metrics = sorted(list(trainer.callbacks[0].metrics.keys())) + + # Define the expected metrics + expected_metrics = metrics_to_track or ["train_loss"] + + # If a batch size is provided, expand metric names to match convention + if batch_size is not None: + expected_metrics = [ + f"{metric}_{suffix}" + for metric in expected_metrics + for suffix in ("step", "epoch") + ] + + assert sorted(logged_metrics) == sorted(expected_metrics) diff --git a/tests/test_callback/test_normalizer_data_callback.py b/tests/test_callback/test_normalizer_data_callback.py deleted file mode 100644 index 431171bd7..000000000 --- a/tests/test_callback/test_normalizer_data_callback.py +++ /dev/null @@ -1,244 +0,0 @@ -import torch -import pytest -from copy import deepcopy - -from pina import Trainer, LabelTensor, Condition -from pina.solver import SupervisedSolver -from pina.model import FeedForward -from pina.callback import NormalizerDataCallback -from pina.problem import BaseProblem -from pina.problem.zoo import Poisson2DSquareProblem as Poisson -from pina.solver import PINN -from pina.graph import RadiusGraph - -# for checking normalization -stage_map = { - "train": ["train_dataset"], - "validate": ["val_dataset"], - "test": ["test_dataset"], - "all": ["train_dataset", "val_dataset", "test_dataset"], -} - -input_1 = torch.rand(20, 2) * 10 -target_1 = torch.rand(20, 1) * 10 -input_2 = torch.rand(20, 2) * 5 -target_2 = torch.rand(20, 1) * 5 - - -class LabelTensorProblem(BaseProblem): - input_variables = ["u_0", "u_1"] - output_variables = ["u"] - conditions = { - "data1": Condition( - input=LabelTensor(input_1, ["u_0", "u_1"]), - target=LabelTensor(target_1, ["u"]), - ), - "data2": Condition( - input=LabelTensor(input_2, ["u_0", "u_1"]), - target=LabelTensor(target_2, ["u"]), - ), - } - - -class TensorProblem(BaseProblem): - input_variables = ["u_0", "u_1"] - output_variables = ["u"] - conditions = { - "data1": Condition(input=input_1, target=target_1), - "data2": Condition(input=input_2, target=target_2), - } - - -input_graph = [RadiusGraph(radius=0.5, pos=torch.rand(10, 2)) for _ in range(5)] -output_graph = torch.rand(5, 1) - - -class GraphProblem(BaseProblem): - input_variables = ["u_0", "u_1"] - output_variables = ["u"] - conditions = { - "data": Condition(input=input_graph, target=output_graph), - } - - -supervised_solver_no_lt = SupervisedSolver( - problem=TensorProblem(), model=FeedForward(2, 1), use_lt=False -) -supervised_solver_lt = SupervisedSolver( - problem=LabelTensorProblem(), model=FeedForward(2, 1), use_lt=True -) - -poisson_problem = Poisson() -poisson_problem.conditions["data"] = Condition( - input=LabelTensor(torch.rand(20, 2) * 10, ["x", "y"]), - target=LabelTensor(torch.rand(20, 1) * 10, ["u"]), -) - - -@pytest.mark.parametrize("scale_fn", [torch.std, 
torch.var]) -@pytest.mark.parametrize("shift_fn", [torch.mean, torch.median]) -@pytest.mark.parametrize("apply_to", ["input", "target"]) -@pytest.mark.parametrize("stage", ["train", "validate", "test", "all"]) -def test_init(scale_fn, shift_fn, apply_to, stage): - normalizer = NormalizerDataCallback( - scale_fn=scale_fn, shift_fn=shift_fn, apply_to=apply_to, stage=stage - ) - assert normalizer.scale_fn == scale_fn - assert normalizer.shift_fn == shift_fn - assert normalizer.apply_to == apply_to - assert normalizer.stage == stage - - -def test_init_invalid_scale(): - with pytest.raises(ValueError): - NormalizerDataCallback(scale_fn=1) - - -def test_init_invalid_shift(): - with pytest.raises(ValueError): - NormalizerDataCallback(shift_fn=1) - - -@pytest.mark.parametrize("invalid_apply_to", ["inputt", "targett", 1]) -def test_init_invalid_apply_to(invalid_apply_to): - with pytest.raises(ValueError): - NormalizerDataCallback(apply_to=invalid_apply_to) - - -@pytest.mark.parametrize("invalid_stage", ["trainn", "validatee", 1]) -def test_init_invalid_stage(invalid_stage): - with pytest.raises(ValueError): - NormalizerDataCallback(stage=invalid_stage) - - -@pytest.mark.parametrize( - "solver", [supervised_solver_lt, supervised_solver_no_lt] -) -@pytest.mark.parametrize( - "fn", [[torch.std, torch.mean], [torch.var, torch.median]] -) -@pytest.mark.parametrize("apply_to", ["input", "target"]) -@pytest.mark.parametrize("stage", ["all", "train", "validate", "test"]) -def test_setup(solver, fn, stage, apply_to): - scale_fn, shift_fn = fn - trainer = Trainer( - solver=solver, - callbacks=NormalizerDataCallback( - scale_fn=scale_fn, shift_fn=shift_fn, stage=stage, apply_to=apply_to - ), - max_epochs=1, - train_size=0.4, - val_size=0.3, - test_size=0.3, - shuffle=False, - ) - trainer_copy = deepcopy(trainer) - trainer_copy.data_module.setup("fit") - trainer_copy.data_module.setup("test") - trainer.train() - trainer.test() - - normalizer = trainer.callbacks[0].normalizer - - for cond in ["data1", "data2"]: - scale = scale_fn( - trainer_copy.data_module.train_dataset.conditions_dict[cond][ - apply_to - ] - ) - shift = shift_fn( - trainer_copy.data_module.train_dataset.conditions_dict[cond][ - apply_to - ] - ) - assert "scale" in normalizer[cond] - assert "shift" in normalizer[cond] - assert normalizer[cond]["scale"] - scale < 1e-5 - assert normalizer[cond]["shift"] - shift < 1e-5 - for ds_name in stage_map[stage]: - dataset = getattr(trainer.data_module, ds_name, None) - old_dataset = getattr(trainer_copy.data_module, ds_name, None) - current_points = dataset.conditions_dict[cond][apply_to] - old_points = old_dataset.conditions_dict[cond][apply_to] - expected = (old_points - shift) / scale - assert torch.allclose(current_points, expected) - - -@pytest.mark.parametrize( - "fn", [[torch.std, torch.mean], [torch.var, torch.median]] -) -@pytest.mark.parametrize("apply_to", ["input"]) -@pytest.mark.parametrize("stage", ["all", "train", "validate", "test"]) -def test_setup_pinn(fn, stage, apply_to): - scale_fn, shift_fn = fn - pinn = PINN( - problem=poisson_problem, - model=FeedForward(2, 1), - ) - poisson_problem.discretise_domain(n=10) - trainer = Trainer( - solver=pinn, - callbacks=NormalizerDataCallback( - scale_fn=scale_fn, - shift_fn=shift_fn, - stage=stage, - apply_to=apply_to, - ), - max_epochs=1, - train_size=0.4, - val_size=0.3, - test_size=0.3, - shuffle=False, - ) - - trainer_copy = deepcopy(trainer) - trainer_copy.data_module.setup("fit") - trainer_copy.data_module.setup("test") - trainer.train() - 
trainer.test()
-
-    conditions = trainer.callbacks[0].normalizer.keys()
-    assert "data" in conditions
-    assert len(conditions) == 1
-    normalizer = trainer.callbacks[0].normalizer
-    cond = "data"
-
-    scale = scale_fn(
-        trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
-    )
-    shift = shift_fn(
-        trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
-    )
-    assert "scale" in normalizer[cond]
-    assert "shift" in normalizer[cond]
-    assert normalizer[cond]["scale"] - scale < 1e-5
-    assert normalizer[cond]["shift"] - shift < 1e-5
-    for ds_name in stage_map[stage]:
-        dataset = getattr(trainer.data_module, ds_name, None)
-        old_dataset = getattr(trainer_copy.data_module, ds_name, None)
-        current_points = dataset.conditions_dict[cond][apply_to]
-        old_points = old_dataset.conditions_dict[cond][apply_to]
-        expected = (old_points - shift) / scale
-        assert torch.allclose(current_points, expected)
-
-
-def test_setup_graph_dataset():
-    solver = SupervisedSolver(
-        problem=GraphProblem(), model=FeedForward(2, 1), use_lt=False
-    )
-    trainer = Trainer(
-        solver=solver,
-        callbacks=NormalizerDataCallback(
-            scale_fn=torch.std,
-            shift_fn=torch.mean,
-            stage="all",
-            apply_to="input",
-        ),
-        max_epochs=1,
-        train_size=0.4,
-        val_size=0.3,
-        test_size=0.3,
-        shuffle=False,
-    )
-    with pytest.raises(NotImplementedError):
-        trainer.train()
diff --git a/tests/test_callback/test_pina_progress_bar.py b/tests/test_callback/test_pina_progress_bar.py
index 8956ebaf0..9ad2b0dc4 100644
--- a/tests/test_callback/test_pina_progress_bar.py
+++ b/tests/test_callback/test_pina_progress_bar.py
@@ -1,34 +1,84 @@
+import torch
+import pytest
 from pina.solver import PINN
-from pina.trainer import Trainer
 from pina.model import FeedForward
 from pina.callback import PINAProgressBar
-from pina.problem.zoo import Poisson2DSquareProblem as Poisson
-
-# make the problem
-poisson_problem = Poisson()
-n = 10
-condition_names = list(poisson_problem.conditions.keys())
-poisson_problem.discretise_domain(n, "grid", domains="boundary")
-poisson_problem.discretise_domain(n, "grid", domains="D")
-model = FeedForward(
-    len(poisson_problem.input_variables), len(poisson_problem.output_variables)
+from pina import Trainer, Condition, LabelTensor
+from pina.problem.zoo import Poisson2DSquareProblem
+
+# Initialize the problem
+problem = Poisson2DSquareProblem()
+problem.discretise_domain(10, "random")
+problem.conditions["data"] = Condition(
+    input=LabelTensor(torch.randn(10, 2), labels=["x", "y"]),
+    target=LabelTensor(torch.randn(10, 1), labels=["u"]),
+)
+
+# Initialize the model and solver
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+solver = PINN(problem=problem, model=model)
+
+# Define the metric selections shared by the parametrized tests below
+metrics_list = ["train", "val", "test", ["test", "data"], ["train", "val"]]
+
+
+@pytest.mark.parametrize(
+    "metrics", metrics_list
 )
+def test_constructor(metrics):
+    PINAProgressBar(metrics=metrics)
 
-# make the solver
-solver = PINN(problem=poisson_problem, model=model)
+    # Should fail if metrics is not a string or list of strings
+    with pytest.raises(ValueError):
+        PINAProgressBar(metrics=123)
 
-def test_progress_bar_constructor():
-    PINAProgressBar()
+@pytest.mark.parametrize("metrics", metrics_list)
+@pytest.mark.parametrize("batch_size", [None, 8])
+def test_routine(metrics, batch_size):
+
+    # Initialize the callback
+    callback = PINAProgressBar(metrics=metrics)
+
+    # Convert to list if a single string is provided
+    if isinstance(metrics, str):
+        metrics = [metrics]
 
-def test_progress_bar_routine():
-    # make the trainer
+    # Initialize the trainer with the callback and train the model
     trainer = Trainer(
         solver=solver,
-        callbacks=[PINAProgressBar(["val", condition_names[0]])],
+        callbacks=[callback],
         accelerator="cpu",
         max_epochs=5,
+        batch_size=batch_size,
+        log_every_n_steps=1,
     )
     trainer.train()
-    # TODO there should be a check that the correct metrics are displayed
+
+    # Get the expected metrics based on the input and batch size
+    suffix = "_loss_epoch" if batch_size is not None else "_loss"
+    expected_metrics = sorted([metric + suffix for metric in metrics])
+
+    # Check that the progress bar metrics are the expected ones
+    assert callback._sorted_metrics == expected_metrics
+
+    # Assert that the displayed metrics are a subset of the expected ones
+    displayed_metrics = {
+        key
+        for key in trainer.progress_bar_metrics
+        if key in callback._sorted_metrics
+    }
+    assert displayed_metrics.issubset(set(expected_metrics))
diff --git a/tests/test_callback/test_r3_refinement.py b/tests/test_callback/test_r3_refinement.py
index f8b9519e9..933fddb6a 100644
--- a/tests/test_callback/test_r3_refinement.py
+++ b/tests/test_callback/test_r3_refinement.py
@@ -1,53 +1,104 @@
+import torch
 import pytest
-from torch.nn import MSELoss
 from pina.solver import PINN
 from pina.trainer import Trainer
 from pina.model import FeedForward
-from pina.problem.zoo import Poisson2DSquareProblem as Poisson
 from pina.callback import R3Refinement
+from pina.problem.zoo import Poisson2DSquareProblem
 
-# make the problem
-poisson_problem = Poisson()
-poisson_problem.discretise_domain(10, "grid", domains="boundary")
-poisson_problem.discretise_domain(10, "grid", domains="D")
-model = FeedForward(
-    len(poisson_problem.input_variables), len(poisson_problem.output_variables)
-)
-solver = PINN(problem=poisson_problem, model=model)
-
-
-def test_constructor():
-    # good constructor
-    R3Refinement(sample_every=10)
-    R3Refinement(sample_every=10, residual_loss=MSELoss)
-    R3Refinement(sample_every=10, condition_to_update=["D"])
-    # wrong constructor
+
+@pytest.mark.parametrize("sample_every", [1, 3])
+@pytest.mark.parametrize("residual_loss", [torch.nn.MSELoss, torch.nn.L1Loss])
+@pytest.mark.parametrize("condition_to_update", [None, ["D"]])
+def test_constructor(sample_every, residual_loss, condition_to_update):
+
+    # Initialize the callback
+    R3Refinement(
+        sample_every=sample_every,
+        residual_loss=residual_loss,
+        condition_to_update=condition_to_update,
+    )
+
+    # Should fail if sample_every is not a positive integer
+    with pytest.raises(AssertionError):
+        R3Refinement(sample_every=0)
+
+    # Should fail if residual_loss is not a valid loss class
     with pytest.raises(ValueError):
-        R3Refinement(sample_every="str")
+        R3Refinement(sample_every=10, residual_loss="not_a_loss")
+
+    # Should fail if condition_to_update is not a string or iterable of strings
     with pytest.raises(ValueError):
-        R3Refinement(sample_every=10, condition_to_update=3)
+        R3Refinement(sample_every=10, condition_to_update=123)
 
+
+@pytest.mark.parametrize("sample_every", [1, 3])
+@pytest.mark.parametrize("residual_loss", [torch.nn.MSELoss, torch.nn.L1Loss])
+@pytest.mark.parametrize("condition_to_update", [None, ["D"], ["boundary"]])
+def test_sample(sample_every, residual_loss, condition_to_update):
residual_loss, condition_to_update): + + # Define the problem, model, and solver for testing + problem = Poisson2DSquareProblem() + problem.discretise_domain(10, "grid", domains="boundary") + problem.discretise_domain(10, "grid", domains="D") + model = FeedForward( + len(problem.input_variables), len(problem.output_variables) + ) + solver = PINN(problem=problem, model=model) -@pytest.mark.parametrize("condition_to_update", [["D"], ["boundary", "D"]]) -def test_sample(condition_to_update): + # Initialize the callback + callback = R3Refinement( + sample_every=sample_every, + residual_loss=residual_loss, + condition_to_update=condition_to_update, + ) + + # Initialize the trainer trainer = Trainer( solver=solver, - callbacks=[ - R3Refinement( - sample_every=1, condition_to_update=condition_to_update - ) - ], + callbacks=callback, accelerator="cpu", max_epochs=5, ) - before_n_points = { - loc: len(trainer.solver.problem.input_pts[loc]) - for loc in condition_to_update + + # Initialize the conditions to update if None + if callback._condition_to_update is None: + callback._condition_to_update = [ + name + for name, cond in solver.problem.conditions.items() + if hasattr(cond, "domain") + ] + + # Check initial population size and dataset before training + n_points_before_train = { + cond: len(trainer.solver.problem.conditions[cond].data.input) + for cond in callback._condition_to_update } + + # Train the model to trigger refinement trainer.train() - after_n_points = { - loc: len(trainer.data_module.train_dataset.input[loc]) - for loc in condition_to_update + + # Check population size after training to ensure it has been updated + n_points_after_train = { + cond: len(trainer.solver.problem.conditions[cond].data.input) + for cond in callback._condition_to_update } - assert before_n_points == trainer.callbacks[0].initial_population_size - assert before_n_points == after_n_points + + # Assert population size has been updated according to the refinement + assert n_points_before_train == trainer.callbacks[0].initial_population_size + assert n_points_before_train == n_points_after_train + + # Should fail if the specified condition does not exist in the problem + with pytest.raises(RuntimeError): + callback = R3Refinement( + sample_every=sample_every, + residual_loss=residual_loss, + condition_to_update="non_existent_condition", + ) + trainer = Trainer( + solver=solver, + callbacks=callback, + accelerator="cpu", + max_epochs=5, + ) + callback.on_train_start(trainer, solver=solver) diff --git a/tests/test_callback/test_switch_optimizer.py b/tests/test_callback/test_switch_optimizer.py index c7490a231..115b7b768 100644 --- a/tests/test_callback/test_switch_optimizer.py +++ b/tests/test_callback/test_switch_optimizer.py @@ -1,15 +1,14 @@ import torch import pytest - from pina.solver import PINN from pina.trainer import Trainer from pina.model import FeedForward from pina.optim import TorchOptimizer from pina.callback import SwitchOptimizer -from pina.problem.zoo import Poisson2DSquareProblem as Poisson +from pina.problem.zoo import Poisson2DSquareProblem # Define the problem -problem = Poisson() +problem = Poisson2DSquareProblem() problem.discretise_domain(10) model = FeedForward(len(problem.input_variables), len(problem.output_variables)) @@ -26,27 +25,35 @@ @pytest.mark.parametrize("epoch_switch", [5, 10]) @pytest.mark.parametrize("new_opt", [lbfgs, adamW]) -def test_switch_optimizer_constructor(new_opt, epoch_switch): +def test_constructor(new_opt, epoch_switch): # Constructor 
SwitchOptimizer(new_optimizers=new_opt, epoch_switch=epoch_switch) - # Should fail if epoch_switch is less than 1 - with pytest.raises(ValueError): + # Should fail if epoch_switch is not a positive integer + with pytest.raises(AssertionError): SwitchOptimizer(new_optimizers=new_opt, epoch_switch=0) + # Should fail if new_optimizers is not an instance of OptimizerInterface + with pytest.raises(ValueError): + SwitchOptimizer( + new_optimizers="not_an_optimizer", epoch_switch=epoch_switch + ) + @pytest.mark.parametrize("epoch_switch", [5, 10]) @pytest.mark.parametrize("new_opt", [lbfgs, adamW]) -def test_switch_optimizer_routine(new_opt, epoch_switch): +def test_routine(new_opt, epoch_switch): # Check if the optimizer is initialized correctly solver.configure_optimizers() - # Initialize the trainer + # Initialize the callback switch_opt_callback = SwitchOptimizer( new_optimizers=new_opt, epoch_switch=epoch_switch ) + + # Initialize the trainer trainer = Trainer( solver=solver, callbacks=switch_opt_callback, diff --git a/tests/test_callback/test_switch_scheduler.py b/tests/test_callback/test_switch_scheduler.py index 36b177853..dc7d55cba 100644 --- a/tests/test_callback/test_switch_scheduler.py +++ b/tests/test_callback/test_switch_scheduler.py @@ -1,15 +1,14 @@ import torch import pytest - from pina.solver import PINN from pina.trainer import Trainer from pina.model import FeedForward from pina.optim import TorchScheduler from pina.callback import SwitchScheduler -from pina.problem.zoo import Poisson2DSquareProblem as Poisson +from pina.problem.zoo import Poisson2DSquareProblem # Define the problem -problem = Poisson() +problem = Poisson2DSquareProblem() problem.discretise_domain(10) model = FeedForward(len(problem.input_variables), len(problem.output_variables)) @@ -31,19 +30,27 @@ def test_switch_scheduler_constructor(new_sched, epoch_switch): # Constructor SwitchScheduler(new_schedulers=new_sched, epoch_switch=epoch_switch) - # Should fail if epoch_switch is less than 1 + # Should fail if epoch_switch is not a positive integer with pytest.raises(AssertionError): SwitchScheduler(new_schedulers=new_sched, epoch_switch=0) + # Should fail if new_schedulers is not an instance of SchedulerInterface + with pytest.raises(ValueError): + SwitchScheduler( + new_schedulers="not_a_scheduler", epoch_switch=epoch_switch + ) + @pytest.mark.parametrize("epoch_switch", [5, 10]) @pytest.mark.parametrize("new_sched", [step, exp]) def test_switch_scheduler_routine(new_sched, epoch_switch): - # Initialize the trainer + # Initialize the callback switch_sched_callback = SwitchScheduler( new_schedulers=new_sched, epoch_switch=epoch_switch ) + + # Initialize the trainer trainer = Trainer( solver=solver, callbacks=switch_sched_callback, diff --git a/tests/test_optim/test_torch_optimizer.py b/tests/test_optim/test_torch_optimizer.py new file mode 100644 index 000000000..dffc04c67 --- /dev/null +++ b/tests/test_optim/test_torch_optimizer.py @@ -0,0 +1,27 @@ +import torch +import pytest +from pina.optim import TorchOptimizer + +opt_list = [torch.optim.Adam, torch.optim.AdamW, torch.optim.SGD] +kwargs_list = [{"lr": 1e-3}, {"lr": 1e-3, "weight_decay": 1e-4}] + + +@pytest.mark.parametrize("optimizer_class", opt_list) +@pytest.mark.parametrize("kwargs", kwargs_list) +def test_constructor(optimizer_class, kwargs): + TorchOptimizer(optimizer_class, **kwargs) + + # Should fail if the optimizer is not subclass of torch.optim.Optimizer + with pytest.raises(ValueError): + TorchOptimizer(object, **kwargs) + + 
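
The `test_hook` below exercises a two-phase construction: `TorchOptimizer` is built without any parameters, so the wrapped `torch.optim` class can only be instantiated later, when `hook` receives a parameter set. A minimal sketch of the same flow outside of pytest (this lazy-instantiation reading is inferred from the API shape the tests exercise, not verified against PINA internals):

    import torch
    from pina.optim import TorchOptimizer

    # Phase 1: wrap the optimizer class and its keyword arguments. No model
    # parameters are available yet, so no torch.optim.Adam object exists here.
    optimizer = TorchOptimizer(torch.optim.Adam, lr=1e-3, weight_decay=1e-4)

    # Phase 2: bind the wrapper to a concrete parameter set. `hook` receives
    # the parameters and instantiates the wrapped optimizer with them, exactly
    # as test_hook does with a throwaway linear layer.
    model = torch.nn.Linear(10, 10)
    optimizer.hook(model.parameters())
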
+@pytest.mark.parametrize("optimizer_class", opt_list)
+@pytest.mark.parametrize("kwargs", kwargs_list)
+def test_hook(optimizer_class, kwargs):
+
+    # Create the optimizer instance
+    optimizer = TorchOptimizer(optimizer_class, **kwargs)
+
+    # Hook the optimizer to the model parameters
+    optimizer.hook(torch.nn.Linear(10, 10).parameters())
diff --git a/tests/test_optim/test_torch_scheduler.py b/tests/test_optim/test_torch_scheduler.py
new file mode 100644
index 000000000..bc7dd96c9
--- /dev/null
+++ b/tests/test_optim/test_torch_scheduler.py
@@ -0,0 +1,37 @@
+import torch
+import pytest
+from pina.optim import TorchOptimizer, TorchScheduler
+
+opt_list = [torch.optim.Adam, torch.optim.AdamW, torch.optim.SGD]
+sch_list = [
+    torch.optim.lr_scheduler.ConstantLR,
+    torch.optim.lr_scheduler.ReduceLROnPlateau,
+]
+
+
+@pytest.mark.parametrize("scheduler_class", sch_list)
+def test_constructor(scheduler_class):
+    TorchScheduler(scheduler_class)
+
+    # Should fail if the scheduler is not a subclass of a torch LRScheduler
+    with pytest.raises(ValueError):
+        TorchScheduler(object)
+
+
+@pytest.mark.parametrize("optimizer_class", opt_list)
+@pytest.mark.parametrize("scheduler_class", sch_list)
+def test_hook(optimizer_class, scheduler_class):
+
+    # Create the optimizer instance and hook it to the model parameters
+    optimizer = TorchOptimizer(optimizer_class)
+    optimizer.hook(torch.nn.Linear(10, 10).parameters())
+
+    # Create the scheduler instance
+    scheduler = TorchScheduler(scheduler_class)
+
+    # Hook the scheduler to the optimizer instance
+    scheduler.hook(optimizer)
+
+    # Should fail if the optimizer is not an instance of OptimizerInterface
+    with pytest.raises(ValueError):
+        scheduler.hook(object)
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
deleted file mode 100644
index 037de9929..000000000
--- a/tests/test_optimizer.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import torch
-import pytest
-from pina.optim import TorchOptimizer
-
-opt_list = [
-    torch.optim.Adam,
-    torch.optim.AdamW,
-    torch.optim.SGD,
-    torch.optim.RMSprop,
-]
-
-
-@pytest.mark.parametrize("optimizer_class", opt_list)
-def test_constructor(optimizer_class):
-    TorchOptimizer(optimizer_class, lr=1e-3)
-
-
-@pytest.mark.parametrize("optimizer_class", opt_list)
-def test_hook(optimizer_class):
-    opt = TorchOptimizer(optimizer_class, lr=1e-3)
-    opt.hook(torch.nn.Linear(10, 10).parameters())
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
deleted file mode 100644
index 157a818d2..000000000
--- a/tests/test_scheduler.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import torch
-import pytest
-from pina.optim import TorchOptimizer, TorchScheduler
-
-opt_list = [
-    torch.optim.Adam,
-    torch.optim.AdamW,
-    torch.optim.SGD,
-    torch.optim.RMSprop,
-]
-
-sch_list = [torch.optim.lr_scheduler.ConstantLR]
-
-
-@pytest.mark.parametrize("scheduler_class", sch_list)
-def test_constructor(scheduler_class):
-    TorchScheduler(scheduler_class)
-
-
-@pytest.mark.parametrize("optimizer_class", opt_list)
-@pytest.mark.parametrize("scheduler_class", sch_list)
-def test_hook(optimizer_class, scheduler_class):
-    opt = TorchOptimizer(optimizer_class, lr=1e-3)
-    opt.hook(torch.nn.Linear(10, 10).parameters())
-    sch = TorchScheduler(scheduler_class)
-    sch.hook(opt)
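
Taken together, the optimizer and callback pieces touched by this patch compose as follows. This is a minimal sketch built only from calls that appear in the tests above; the learning rate, epoch counts, and the choice of AdamW are illustrative values, not recommendations:

    import torch
    from pina import Trainer
    from pina.solver import PINN
    from pina.model import FeedForward
    from pina.optim import TorchOptimizer
    from pina.callback import SwitchOptimizer
    from pina.problem.zoo import Poisson2DSquareProblem

    # Problem, model, and solver, set up as in the switch-optimizer tests.
    problem = Poisson2DSquareProblem()
    problem.discretise_domain(10)
    model = FeedForward(
        len(problem.input_variables), len(problem.output_variables)
    )
    solver = PINN(problem=problem, model=model)

    # Start from the solver's default optimizer and swap to AdamW at epoch 5.
    switch = SwitchOptimizer(
        new_optimizers=TorchOptimizer(torch.optim.AdamW, lr=1e-3),
        epoch_switch=5,
    )
    trainer = Trainer(
        solver=solver, callbacks=switch, accelerator="cpu", max_epochs=10
    )
    trainer.train()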