Data Normalizer
Module for the Data Normalizer callback.
- class DataNormalizer(scale_fn=torch.std, shift_fn=torch.mean, stage='all', apply_to='input')
Bases: Callback
Callback for dataset normalization on input-target conditions.
This callback computes and applies a normalization transform to either input or target tensors within a dataset. The transformation is defined as:
\[x_{\text{norm}} = \frac{x - \mu}{\sigma},\]
where \(\mu\) and \(\sigma\) are computed using the provided shift_fn and scale_fn functions, respectively. Normalization parameters are estimated from the training dataset and then applied in place to the selected datasets, depending on the chosen stage.
Note
This callback ignores all conditions that are not instances of InputTargetCondition.
- Example:
>>> import torch
>>> DataNormalizer(
...     scale_fn=torch.std,
...     shift_fn=torch.mean,
...     stage="all",
...     apply_to="input",
... )
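The normalization formula can be illustrated with plain PyTorch tensors. This is a minimal sketch, not the callback's own code: the tensors are made-up stand-ins for the input tensors of an InputTargetCondition, and it assumes shift_fn and scale_fn are called on the whole tensor (yielding scalar statistics), which may differ from how DataNormalizer actually reduces over dimensions.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-ins for the input tensors of the training and test sets.
train_x = torch.randn(100, 3) * 4.0 + 2.0
test_x = torch.randn(20, 3) * 4.0 + 2.0
test_x_original = test_x.clone()

shift_fn, scale_fn = torch.mean, torch.std
# Statistics come from the training data only. Assumption: each function is
# called on the whole tensor, so mu and sigma are 0-d (scalar) tensors.
mu = shift_fn(train_x)
sigma = scale_fn(train_x)

# Apply x_norm = (x - mu) / sigma in place to every selected dataset.
for x in (train_x, test_x):
    x.sub_(mu).div_(sigma)

print(f"train mean={train_x.mean().item():.4f}, std={train_x.std().item():.4f}")
```

After the loop the training tensor has mean close to 0 and standard deviation close to 1. The test tensor is shifted and scaled with the training statistics, so its own mean and standard deviation are only approximately 0 and 1.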
Initialization of the DataNormalizer class.
- Parameters:
scale_fn (Callable) – The function used to compute the scaling factor. Default is torch.std.
shift_fn (Callable) – The function used to compute the shifting factor. Default is torch.mean.
stage (str) – The stage during which normalization is applied. Available options are "train", "validate", "test", and "all". Default is "all".
apply_to (str) – Specifies whether normalization is applied to "input" or "target" tensors. Default is "input".
- Raises:
ValueError – If scale_fn is not Callable.
ValueError – If shift_fn is not Callable.
ValueError – If stage is not one of "train", "validate", "test", or "all".
ValueError – If apply_to is not one of "input" or "target".
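The documented ValueError conditions amount to a few checks on the constructor arguments. The helpers below are hypothetical (validate_normalizer_args and stages_to_normalize are not part of the documented API), and the mapping from "all" to the three dataset splits is an assumption based on the stage description.

```python
# Option sets taken from the parameter descriptions above.
VALID_STAGES = {"train", "validate", "test", "all"}
VALID_TARGETS = {"input", "target"}


def validate_normalizer_args(scale_fn, shift_fn, stage, apply_to):
    """Hypothetical helper mirroring the documented ValueError conditions."""
    if not callable(scale_fn):
        raise ValueError("scale_fn must be callable")
    if not callable(shift_fn):
        raise ValueError("shift_fn must be callable")
    if stage not in VALID_STAGES:
        raise ValueError(f"stage must be one of {sorted(VALID_STAGES)}, got {stage!r}")
    if apply_to not in VALID_TARGETS:
        raise ValueError(f"apply_to must be one of {sorted(VALID_TARGETS)}, got {apply_to!r}")


def stages_to_normalize(stage):
    """Hypothetical mapping from the stage option to dataset splits."""
    return ["train", "validate", "test"] if stage == "all" else [stage]


# Usage: "fit" is not a valid stage option, so a ValueError is raised.
try:
    validate_normalizer_args(len, len, "fit", "input")
except ValueError as err:
    print(err)
```

Checking the arguments eagerly in the constructor means a typo such as stage="fit" fails when the callback is created rather than later during setup.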
- setup(trainer, pl_module, stage)
Compute and apply normalization during the setup phase.
- Parameters:
- Raises:
NotImplementedError – If the dataset is graph-based and therefore unsupported.
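The setup hook runs under PyTorch Lightning, which calls Callback.setup(trainer, pl_module, stage) at the start of fit, validate, test, or predict. The sketch below imitates the described behavior on plain dictionaries instead of real datasets: the layout of splits, the "kind" field, the function name normalize_splits, and the use of a non-tensor value as a stand-in for graph data are all hypothetical.

```python
import torch


def normalize_splits(splits, shift_fn=torch.mean, scale_fn=torch.std,
                     stage="all", apply_to="input"):
    """Hypothetical sketch of the setup logic on plain dictionaries.

    splits maps a split name ("train", "validate", "test") to a dict of
    conditions; each condition is a dict with a "kind" string and
    "input"/"target" tensors. This layout is an assumption.
    """
    params = {}
    # 1. Estimate mu and sigma per condition from the training split only.
    for name, cond in splits["train"].items():
        if cond["kind"] != "InputTargetCondition":
            continue  # other condition types are ignored
        data = cond[apply_to]
        if not isinstance(data, torch.Tensor):
            # Stand-in for the graph-based case the callback rejects.
            raise NotImplementedError("graph-based datasets are not supported")
        params[name] = (shift_fn(data), scale_fn(data))

    # 2. Apply the training statistics in place to the selected splits.
    selected = ["train", "validate", "test"] if stage == "all" else [stage]
    for split in selected:
        for name, (mu, sigma) in params.items():
            if name in splits.get(split, {}):
                splits[split][name][apply_to].sub_(mu).div_(sigma)
    return params


torch.manual_seed(0)
splits = {
    "train": {
        "data": {"kind": "InputTargetCondition",
                 "input": torch.randn(50, 2) * 3 + 1,
                 "target": torch.randn(50, 1)},
        # Hypothetical non-InputTargetCondition entry: skipped entirely.
        "residual": {"kind": "OtherCondition"},
    },
    "test": {
        "data": {"kind": "InputTargetCondition",
                 "input": torch.randn(10, 2) * 3 + 1,
                 "target": torch.randn(10, 1)},
    },
}
target_before = splits["train"]["data"]["target"].clone()
params = normalize_splits(splits, stage="all", apply_to="input")
mu, sigma = params["data"]
print(f"mu={mu.item():.3f}, sigma={sigma.item():.3f}")
```

Only the "input" tensors are modified here because apply_to="input"; the "target" tensors and the ignored condition are left untouched.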