Source code for pina.callback.normalizer_data_callback

"""Module for the Normalizer callback."""

import torch
from lightning.pytorch import Callback
from ..label_tensor import LabelTensor
from ..utils import check_consistency, is_function
from ..condition import InputTargetCondition
from ..data.dataset import PinaGraphDataset


class NormalizerDataCallback(Callback):
    r"""
    A Callback used to normalize the dataset inputs or targets according to
    user-provided scale and shift functions.

    The transformation is applied as:

    .. math::

        x_{\text{new}} = \frac{x - \text{shift}}{\text{scale}}

    The scale and shift parameters are computed once, per condition, from
    the training dataset, and then re-used for every dataset split that is
    normalized.

    :Example:

        >>> NormalizerDataCallback()
        >>> NormalizerDataCallback(
        ...     scale_fn=torch.std,
        ...     shift_fn=torch.mean,
        ...     stage="all",
        ...     apply_to="input",
        ... )
    """

    def __init__(
        self,
        scale_fn=torch.std,
        shift_fn=torch.mean,
        stage="all",
        apply_to="input",
    ):
        """
        Initialization of the :class:`NormalizerDataCallback` class.

        :param Callable scale_fn: The function to compute the scaling factor.
            Default is ``torch.std``.
        :param Callable shift_fn: The function to compute the shifting
            factor. Default is ``torch.mean``.
        :param str stage: The stage in which normalization is applied.
            Accepted values are "train", "validate", "test", or "all".
            Default is ``"all"``.
        :param str apply_to: Whether to normalize "input" or "target" data.
            Default is ``"input"``.
        :raises ValueError: If ``apply_to`` is neither "input" nor "target".
        :raises ValueError: If ``stage`` is not one of the accepted values.
        :raises ValueError: If ``scale_fn`` is not callable.
        :raises ValueError: If ``shift_fn`` is not callable.
        """
        super().__init__()

        # Validate parameters
        self.apply_to = self._validate_apply_to(apply_to)
        self.stage = self._validate_stage(stage)

        # Validate functions
        if not is_function(scale_fn):
            raise ValueError(f"scale_fn must be Callable, got {scale_fn}")
        if not is_function(shift_fn):
            raise ValueError(f"shift_fn must be Callable, got {shift_fn}")
        self.scale_fn = scale_fn
        self.shift_fn = shift_fn

        # Maps each condition name to its {"shift": ..., "scale": ...}
        # parameters; filled lazily on the first call to ``setup``.
        self._normalizer = {}

    def _validate_apply_to(self, apply_to):
        """
        Validate the ``apply_to`` parameter.

        :param str apply_to: The candidate value for the ``apply_to``
            parameter.
        :raises ValueError: If ``apply_to`` is neither "input" nor "target".
        :return: The validated ``apply_to`` value.
        :rtype: str
        """
        check_consistency(apply_to, str)
        if apply_to not in {"input", "target"}:
            raise ValueError(
                f"apply_to must be either 'input' or 'target', got {apply_to}"
            )
        return apply_to

    def _validate_stage(self, stage):
        """
        Validate the ``stage`` parameter.

        :param str stage: The candidate value for the ``stage`` parameter.
        :raises ValueError: If ``stage`` is not one of "train", "validate",
            "test", or "all".
        :return: The validated ``stage`` value.
        :rtype: str
        """
        check_consistency(stage, str)
        if stage not in {"train", "validate", "test", "all"}:
            raise ValueError(
                "stage must be one of 'train', 'validate', 'test', or 'all',"
                f" got {stage}"
            )
        return stage

    def setup(self, trainer, pl_module, stage):
        """
        Apply normalization during setup.

        The normalization parameters are computed from the training dataset
        the first time they are needed. The datasets matching both the
        current ``stage`` and the ``stage`` chosen at construction are then
        normalized in place.

        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
        :param SolverInterface pl_module: A
            :class:`~pina.solver.solver.SolverInterface` instance.
        :param str stage: The current stage.
        :raises NotImplementedError: If the dataset is graph-based.
        :raises RuntimeError: If the training dataset is not available when
            computing normalization parameters.
        :return: The result of the parent setup.
        :rtype: Any
        """
        datamodule = trainer.datamodule
        train_dataset = datamodule.train_dataset

        # Ensure datasets are not graph-based
        if isinstance(train_dataset, PinaGraphDataset):
            raise NotImplementedError(
                "NormalizerDataCallback is not compatible with "
                "graph-based datasets."
            )

        # Only input-target conditions carry data that can be normalized
        conditions_to_normalize = [
            name
            for name, cond in pl_module.problem.conditions.items()
            if isinstance(cond, InputTargetCondition)
        ]

        # Compute scale and shift parameters (only if not already stored)
        if not self.normalizer:
            if not train_dataset:
                raise RuntimeError(
                    "Training dataset is not available. Cannot compute "
                    "normalization parameters."
                )
            self._compute_scale_shift(conditions_to_normalize, train_dataset)

        # Pair each requested split with the datamodule attribute holding it
        if stage == "fit":
            splits = (("train", "train_dataset"), ("validate", "val_dataset"))
        elif stage == "test":
            splits = (("test", "test_dataset"),)
        else:
            splits = ()

        for split_name, attribute in splits:
            if self.stage not in (split_name, "all"):
                continue
            # A split may be missing (e.g. no validation data configured);
            # there is nothing to normalize in that case, so skip it rather
            # than failing on a missing attribute or a ``None`` dataset.
            dataset = getattr(datamodule, attribute, None)
            if dataset is not None:
                self.normalize_dataset(dataset)

        return super().setup(trainer, pl_module, stage)

    def _compute_scale_shift(self, conditions, dataset):
        """
        Compute scale and shift parameters for each condition in the
        dataset and store them in the normalizer dictionary.

        Conditions that are not present in ``dataset.conditions_dict`` are
        skipped.

        :param list conditions: The list of condition names.
        :param dataset: The `~pina.data.dataset.PinaDataset` dataset.
        """
        for cond in conditions:
            if cond not in dataset.conditions_dict:
                continue
            data = dataset.conditions_dict[cond][self.apply_to]
            self._normalizer[cond] = {
                "shift": self.shift_fn(data),
                "scale": self.scale_fn(data),
            }

    @staticmethod
    def _norm_fn(value, scale, shift):
        """
        Normalize a value according to the scale and shift parameters.

        :param value: The input tensor to normalize.
        :type value: torch.Tensor | LabelTensor
        :param float scale: The scaling factor.
        :param float shift: The shifting factor.
        :return: The normalized tensor.
        :rtype: torch.Tensor | LabelTensor
        """
        # NOTE(review): there is no guard against a zero ``scale`` (e.g. the
        # std of constant data); confirm whether that case needs handling.
        scaled_value = (value - shift) / scale

        # Re-attach the original labels to the result of the arithmetic
        if isinstance(value, LabelTensor):
            scaled_value = LabelTensor(scaled_value, value.labels)
        return scaled_value

    def normalize_dataset(self, dataset):
        """
        Apply in-place normalization to the dataset.

        Only the conditions that have stored normalization parameters and
        are present in ``dataset.conditions_dict`` are normalized.

        :param PinaDataset dataset: The dataset to be normalized.
        """
        update_dataset_dict = {}

        for cond, norm_params in self.normalizer.items():
            # Mirror ``_compute_scale_shift``: a condition missing from this
            # dataset is skipped instead of raising a ``KeyError``.
            if cond not in dataset.conditions_dict:
                continue
            points = dataset.conditions_dict[cond][self.apply_to]
            # ``_norm_fn`` already preserves the labels of a LabelTensor
            update_dataset_dict[cond] = {
                self.apply_to: self._norm_fn(
                    points, norm_params["scale"], norm_params["shift"]
                )
            }

        # Update the dataset in-place
        dataset.update_data(update_dataset_dict)

    @property
    def normalizer(self):
        """
        Get the dictionary of normalization parameters.

        :return: The dictionary of normalization parameters.
        :rtype: dict
        """
        return self._normalizer