Data Normalizer

Module for the Data Normalizer callback.

class DataNormalizer(scale_fn=torch.std, shift_fn=torch.mean, stage='all', apply_to='input')

Bases: Callback

Callback for dataset normalization on input-target conditions.

This callback computes and applies a normalization transform to either input or target tensors within a dataset. The transformation is defined as:

\[x_{\text{norm}} = \frac{x - \mu}{\sigma},\]

where \(\mu\) and \(\sigma\) are computed by shift_fn and scale_fn, respectively. The normalization parameters are estimated from the training dataset and then applied in-place to the datasets selected by stage.
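To make the formula concrete, the following minimal sketch fits \(\mu\) and \(\sigma\) on a small training sample and reuses them on a second sample. It assumes plain Python lists instead of tensors and uses statistics.mean and statistics.stdev as stand-ins for torch.mean and torch.std (statistics.stdev, like the default torch.std, uses the n - 1 estimator). The helper names fit_normalizer and normalize are illustrative only and are not part of the library.

```python
from statistics import mean, stdev


def fit_normalizer(train_values):
    """Estimate (shift, scale) from the training values.

    statistics.mean / statistics.stdev stand in for torch.mean / torch.std.
    """
    return mean(train_values), stdev(train_values)


def normalize(values, shift, scale):
    """Apply x_norm = (x - mu) / sigma element-wise."""
    return [(x - shift) / scale for x in values]


train = [2.0, 4.0, 6.0, 8.0]
test = [5.0, 10.0]

mu, sigma = fit_normalizer(train)          # parameters come from training data
train_norm = normalize(train, mu, sigma)   # zero mean, unit sample std
test_norm = normalize(test, mu, sigma)     # reuses the training statistics
```

Because the test sample is transformed with the training statistics, its normalized values are generally not centred at zero; this mirrors the documented behaviour of estimating parameters on the training dataset and applying them elsewhere.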

Note

This callback ignores all conditions that are not instances of InputTargetCondition.
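The filtering rule in the note can be pictured with a small sketch. InputTargetLike and EquationLike are hypothetical stand-ins (not PINA classes) for a condition that carries input and target data and one that carries no target; the point is only that unsupported conditions are skipped rather than rejected.

```python
from dataclasses import dataclass


@dataclass
class InputTargetLike:
    """Hypothetical stand-in for an input-target condition."""
    input: list
    target: list


@dataclass
class EquationLike:
    """Hypothetical stand-in for a condition without target data."""
    input: list


def conditions_to_normalize(conditions):
    """Return the names of the conditions that would be normalized.

    Anything that is not an input-target condition is silently skipped.
    """
    return [
        name
        for name, cond in conditions.items()
        if isinstance(cond, InputTargetLike)
    ]


conditions = {
    "data": InputTargetLike(input=[1.0, 2.0], target=[3.0, 4.0]),
    "physics": EquationLike(input=[0.0, 0.5]),
}
print(conditions_to_normalize(conditions))  # ['data']
```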

Example:
>>> DataNormalizer(
...     scale_fn=torch.std,
...     shift_fn=torch.mean,
...     stage="all",
...     apply_to="input",
... )

Initialization of the DataNormalizer class.

Parameters:
  • scale_fn (Callable) – The function used to compute the scaling factor. Default is torch.std.

  • shift_fn (Callable) – The function used to compute the shifting factor. Default is torch.mean.

  • stage (str) – The stage during which normalization is applied. Available options are "train", "validate", "test", and "all". Default is "all".

  • apply_to (str) – Specifies whether normalization is applied to "input" or "target" tensors. Default is "input".
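The two string options can be sketched as follows. This assumes that stage="all" covers the train, validate and test datasets and that any other valid value selects a single dataset; the ValueError raised for unknown values is also an assumption, since the exception type is not documented here. VALID_STAGES, VALID_APPLY_TO, check_options and splits_for_stage are hypothetical names.

```python
# Hypothetical constants mirroring the documented options.
VALID_STAGES = ("train", "validate", "test", "all")
VALID_APPLY_TO = ("input", "target")


def check_options(stage, apply_to):
    """Validate the two string options (the exception type is assumed)."""
    if stage not in VALID_STAGES:
        raise ValueError(f"stage must be one of {VALID_STAGES}, got {stage!r}")
    if apply_to not in VALID_APPLY_TO:
        raise ValueError(
            f"apply_to must be one of {VALID_APPLY_TO}, got {apply_to!r}"
        )


def splits_for_stage(stage):
    """Return the datasets a given stage value would normalize (assumed)."""
    return ["train", "validate", "test"] if stage == "all" else [stage]


check_options("all", "input")
print(splits_for_stage("all"))       # ['train', 'validate', 'test']
print(splits_for_stage("validate"))  # ['validate']
```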

Raises:

setup(trainer, pl_module, stage)

Compute and apply normalization during the setup phase.

Parameters:
  • trainer (Trainer) – The trainer instance managing the execution.

  • pl_module (BaseSolver) – The solver module being executed.

  • stage (str) – Current execution stage.

Raises:

NotImplementedError – If the dataset is graph-based and therefore unsupported.
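The parameter-estimation part of setup can be sketched as a loop over the training conditions that stores one shift and one scale per condition. This is a simplified stand-in, not the library code: the training data is assumed to be a dict of plain lists, GraphData is a hypothetical marker for graph-structured data, and the {"shift": ..., "scale": ...} layout is an assumption.

```python
from statistics import mean, stdev


class GraphData:
    """Hypothetical marker type for graph-structured data (unsupported)."""


def compute_normalizer(train_data, apply_to="input",
                       shift_fn=mean, scale_fn=stdev):
    """Estimate a per-condition shift and scale from the training data."""
    normalizer = {}
    for name, fields in train_data.items():
        values = fields[apply_to]
        if isinstance(values, GraphData):
            # Mirrors the documented NotImplementedError for graph datasets.
            raise NotImplementedError("graph-based datasets are not supported")
        normalizer[name] = {"shift": shift_fn(values), "scale": scale_fn(values)}
    return normalizer


train_data = {"data": {"input": [1.0, 3.0, 5.0], "target": [0.0, 1.0, 2.0]}}
print(compute_normalizer(train_data))
# {'data': {'shift': 3.0, 'scale': 2.0}}
```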

normalize_dataset(dataset)

Apply normalization in-place to every condition in the given dataset.

Each condition is updated using the precomputed normalization parameters. The transformation preserves the original tensor types.

Parameters:

dataset (dict) – The mapping between condition names and their associated dataset subsets.
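A rough picture of an in-place update that keeps the container type is sketched below. It assumes plain lists instead of tensors: TaggedList is a hypothetical list subclass standing in for a labelled tensor type, and normalize_in_place is an illustrative helper, not the method documented above. Slice assignment rewrites the existing object instead of rebinding the name, so identity and type survive.

```python
class TaggedList(list):
    """Hypothetical list subclass standing in for a labelled tensor type."""


def normalize_in_place(dataset, normalizer, apply_to="input"):
    """Normalize the selected field of every condition that has parameters."""
    for name, fields in dataset.items():
        if name not in normalizer:
            continue  # conditions without parameters are left untouched
        shift = normalizer[name]["shift"]
        scale = normalizer[name]["scale"]
        values = fields[apply_to]
        # Slice assignment overwrites the existing container, so the
        # object (and therefore its type) is preserved.
        values[:] = [(x - shift) / scale for x in values]


val_data = {"data": {"input": TaggedList([3.0, 5.0]), "target": [7.0]}}
original = val_data["data"]["input"]
normalize_in_place(val_data, {"data": {"shift": 3.0, "scale": 2.0}})
print(val_data["data"]["input"])  # [0.0, 1.0]
```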

property normalizer

The dictionary mapping each condition to its corresponding shift and scale values.

Returns:

The dictionary of normalization parameters.

Return type:

dict
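One plausible use of these parameters is mapping normalized values back to their original units by inverting the formula, \(x = \sigma\, x_{\text{norm}} + \mu\). The sketch below assumes each entry of the returned dictionary has "shift" and "scale" keys; that layout, and the denormalize helper, are assumptions for illustration only.

```python
def denormalize(values, params):
    """Invert x_norm = (x - mu) / sigma, i.e. compute x = x_norm * sigma + mu."""
    return [x * params["scale"] + params["shift"] for x in values]


# Assumed layout of the dictionary returned by the normalizer property.
normalizer = {"data": {"shift": 3.0, "scale": 2.0}}
print(denormalize([0.0, 1.0], normalizer["data"]))  # [3.0, 5.0]
```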