Source code for pina._src.callback.processing.data_normalizer

"""Module for the Data Normalizer callback."""

from typing import Callable
import torch
from lightning.pytorch import Callback
from pina._src.core.utils import check_consistency
from pina._src.core.label_tensor import LabelTensor
from pina._src.condition.condition import InputTargetCondition


[docs] class DataNormalizer(Callback): r""" Callback for dataset normalization on input-target conditions. This callback computes and applies a normalization transform to either input or target tensors within a dataset. The transformation is defined as: .. math:: x_{\text{norm}} = \frac{x - \mu}{\sigma}, where :math:`\mu` and :math:`\sigma` are computed using the provided ``shift_fn`` and ``scale_fn`` functions, respectively. Normalization parameters are estimated from the training dataset and then applied in-place to the selected datasets depending on the chosen stage. .. note:: This callback ignores all conditions that are not instances of :class:`~pina.condition.InputTargetCondition`. :Example: >>> DataNormalizer( ... scale_fn=torch.std, ... shift_fn=torch.mean, ... stage="all", ... apply_to="input", ... ) """ # Define valid options for stage and apply_to parameters _VALID_STAGES = {"train", "validate", "test", "all"} _VALID_APPLY_TO = {"input", "target"} def __init__( self, scale_fn=torch.std, shift_fn=torch.mean, stage="all", apply_to="input", ): """ Initialization of the :class:`DataNormalizer` class. :param Callable scale_fn: The function used to compute the scaling factor. Default is ``torch.std``. :param Callable shift_fn: The function used to compute the shifting factor. Default is ``torch.mean``. :param str stage: The stage during which normalization is applied. Available options are ``"train"``, ``"validate"``, ``"test"``, and ``"all"``. Default is ``"all"``. :param str apply_to: Specifies whether normalization is applied to ``"input"`` or ``"target"`` tensors. Default is ``"input"``. :raises ValueError: If ``scale_fn`` is not Callable. :raises ValueError: If ``shift_fn`` is not Callable. :raises ValueError: If ``stage`` is invalid. :raises ValueError: If ``apply_to`` is invalid. """ super().__init__() # Check consistency check_consistency(scale_fn, Callable) check_consistency(shift_fn, Callable) check_consistency(stage, str) check_consistency(apply_to, str) # Validate stage parameter if stage not in self._VALID_STAGES: raise ValueError( "Invalid value for 'stage'. Available options are " f"{self._VALID_STAGES}. Got {stage}." ) # Validate apply_to parameter if apply_to not in self._VALID_APPLY_TO: raise ValueError( "Invalid value for 'apply_to'. Available options are " f"{self._VALID_APPLY_TO}. Got {apply_to}." ) # Initialize attributes self.scale_fn = scale_fn self.shift_fn = shift_fn self.stage = stage self.apply_to = apply_to self._normalizer = {} self._normalized_conditions = set()
[docs] def setup(self, trainer, pl_module, stage): """ Compute and apply normalization during the setup phase. :param Trainer trainer: The trainer instance managing the execution. :param BaseSolver pl_module: The solver module being executed. :param str stage: Current execution stage. :raises NotImplementedError: If the dataset is graph-based and therefore unsupported. """ # Check if any condition contains graph-based data if any( hasattr(ds.condition.data, "graph_key") for ds in trainer.datamodule.train_datasets.values() ): raise NotImplementedError( "DataNormalizer is not compatible with graph-based datasets." ) # Extract input-target conditions conditions_to_normalize = [ name for name, cond in pl_module.problem.conditions.items() if isinstance(cond, InputTargetCondition) ] # Extract the dictionary of all datasets dataset = trainer.datamodule.train_datasets # Compute scale and shift parameters if not already computed if not self.normalizer: # Iterate over conditions and compute normalization parameters for cond in conditions_to_normalize: pts = self._get_data(dataset, cond) shift = self.shift_fn(pts) scale = self.scale_fn(pts) self._normalizer[cond] = { "shift": shift, "scale": scale, } # Apply normalization to training datasets if stage == "fit" and self.stage in ["train", "all"]: self.normalize_dataset(trainer.datamodule.train_datasets) if stage == "fit" and self.stage in ["validate", "all"]: self.normalize_dataset(trainer.datamodule.val_datasets) if stage == "test" and self.stage in ["test", "all"]: self.normalize_dataset(trainer.datamodule.test_datasets) return super().setup(trainer, pl_module, stage)
[docs] def normalize_dataset(self, dataset): """ Apply normalization to all datasets in-place. Each condition is updated using precomputed normalization parameters. The transformation preserves tensor types. :param dict dataset: The mapping between condition names and their associated dataset subsets. """ # Iterate over conditions and apply normalization for cond, norm_params in self.normalizer.items(): if cond in self._normalized_conditions: continue # Extract the points to normalize and the normalization parameters data_container = getattr(dataset[cond].condition, self.apply_to) points = data_container.data scale = norm_params["scale"] shift = norm_params["shift"] # Apply normalization scaled_pts = (points - shift) / scale if isinstance(data_container, LabelTensor): scaled_pts = LabelTensor(scaled_pts, data_container.labels) # Update the dataset in-place data_container.data = scaled_pts self._normalized_conditions.add(cond)
def _get_data(self, dataset, cond): """ Extract the selected data field from the dataset for a given condition. :param dict dataset: The mapping between condition names and their associated dataset subsets. :param str cond: The condition name. :return: The selected input or target data. :rtype: torch.Tensor """ return getattr(dataset[cond].condition, self.apply_to).data @property def normalizer(self): """ The dictionary mapping each condition to its corresponding ``shift`` and ``scale`` values. :return: The dictionary of normalization parameters. :rtype: dict """ return self._normalizer