Normalizer callbacks
- class NormalizerDataCallback(scale_fn=torch.std, shift_fn=torch.mean, stage='all', apply_to='input')
Bases: Callback
A callback that normalizes the dataset inputs or targets using user-provided scale and shift functions.
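The transformation is applied as:
\[x_{\text{new}} = \frac{x - \text{shift}}{\text{scale}}\]
where \(\text{shift}\) and \(\text{scale}\) are the values returned by shift_fn and scale_fn, respectively.
- Example: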
>>> import torch
>>> NormalizerDataCallback()
>>> NormalizerDataCallback(
...     scale_fn=torch.std,
...     shift_fn=torch.mean,
...     stage="all",
...     apply_to="input",
... )
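As a rough illustration of the formula above, the sketch below applies it to a plain Python list. It is not PINA's implementation: the helper name `normalize` is hypothetical, and `statistics.stdev` / `statistics.mean` stand in for the default `torch.std` / `torch.mean`. Both `torch.std` (by default) and `statistics.stdev` use the Bessel-corrected sample standard deviation.

```python
import statistics
from typing import Callable, Sequence


def normalize(
    x: Sequence[float],
    scale_fn: Callable[[Sequence[float]], float] = statistics.stdev,  # stand-in for torch.std
    shift_fn: Callable[[Sequence[float]], float] = statistics.mean,   # stand-in for torch.mean
) -> list[float]:
    """Return (x - shift) / scale, with shift and scale computed from x itself."""
    shift = shift_fn(x)
    scale = scale_fn(x)
    return [(value - shift) / scale for value in x]


data = [2.0, 4.0, 6.0, 8.0]
print(normalize(data))  # zero mean, unit sample standard deviation
```

Because both factors are arbitrary callables, the same formula covers other schemes: for example, `shift_fn=min` with `scale_fn=lambda v: max(v) - min(v)` gives min-max scaling to [0, 1].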
Initialization of the NormalizerDataCallback class.
- Parameters:
scale_fn (Callable) – The function to compute the scaling factor. Default is torch.std.
shift_fn (Callable) – The function to compute the shifting factor. Default is torch.mean.
stage (str) – The stage in which normalization is applied. Accepted values are "train", "validate", "test", or "all". Default is "all".
apply_to (str) – Whether to normalize "input" or "target" data. Default is "input".
- Raises:
ValueError – If scale_fn is not callable.
ValueError – If shift_fn is not callable.
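The documented constructor checks can be mirrored by a small stand-alone class. `NormalizerConfigSketch` is hypothetical and does not import PINA. The `ValueError` checks for `scale_fn` and `shift_fn` follow the Raises list above; rejecting unknown `stage` or `apply_to` values is an assumption based only on the accepted values listed under Parameters.

```python
import statistics
from typing import Callable

# Accepted values copied from the parameter descriptions.
STAGES = {"train", "validate", "test", "all"}
APPLY_TO = {"input", "target"}


class NormalizerConfigSketch:
    """Hypothetical stand-in that mirrors the documented argument checks."""

    def __init__(
        self,
        scale_fn: Callable = statistics.stdev,  # stand-in for torch.std
        shift_fn: Callable = statistics.mean,   # stand-in for torch.mean
        stage: str = "all",
        apply_to: str = "input",
    ) -> None:
        # Documented: a non-callable scale_fn or shift_fn raises ValueError.
        if not callable(scale_fn):
            raise ValueError("scale_fn must be callable")
        if not callable(shift_fn):
            raise ValueError("shift_fn must be callable")
        # Assumption: out-of-range stage / apply_to values are also rejected.
        if stage not in STAGES:
            raise ValueError(f"stage must be one of {sorted(STAGES)}, got {stage!r}")
        if apply_to not in APPLY_TO:
            raise ValueError(f"apply_to must be one of {sorted(APPLY_TO)}, got {apply_to!r}")
        self.scale_fn = scale_fn
        self.shift_fn = shift_fn
        self.stage = stage
        self.apply_to = apply_to


try:
    NormalizerConfigSketch(scale_fn=1.0)
except ValueError as err:
    print(err)  # scale_fn must be callable
```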
- setup(trainer, pl_module, stage)
Apply normalization during setup.
- Parameters:
trainer (Trainer) – A Trainer instance.
pl_module (SolverInterface) – A SolverInterface instance.
stage (str) – The current stage.
- Raises:
RuntimeError – If the training dataset is not available when computing normalization parameters.
NotImplementedError – If the dataset is graph-based.
- Returns:
The result of the parent setup.
- Return type:
Any
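The Raises entries of setup suggest that the shift and scale are computed from the training split before any split is normalized. The sketch below illustrates that reading with plain dictionaries. `setup_sketch`, the dataset layout, and the `is_graph` flag are hypothetical, and reusing the training statistics for the validation and test splits is an assumption rather than confirmed PINA behavior.

```python
import statistics
from typing import Callable, Optional

# Hypothetical layout: each split is {"input": [...], "target": [...]},
# optionally with "is_graph": True for graph-based data.
Dataset = dict


def setup_sketch(
    datasets: dict[str, Optional[Dataset]],
    stage: str = "all",
    apply_to: str = "input",
    scale_fn: Callable = statistics.stdev,
    shift_fn: Callable = statistics.mean,
) -> tuple[float, float]:
    """Compute shift/scale on the training split, then normalize the selected splits."""
    train = datasets.get("train")
    if train is None:
        # Mirrors the documented RuntimeError.
        raise RuntimeError("Training dataset is required to compute normalization parameters.")
    if any(ds is not None and ds.get("is_graph", False) for ds in datasets.values()):
        # Mirrors the documented NotImplementedError for graph-based data.
        raise NotImplementedError("Graph-based datasets are not supported.")

    values = train[apply_to]
    shift, scale = shift_fn(values), scale_fn(values)

    # Assumption: "all" means the train, validate, and test splits.
    splits = ("train", "validate", "test") if stage == "all" else (stage,)
    for name in splits:
        ds = datasets.get(name)
        if ds is not None:
            ds[apply_to] = [(v - shift) / scale for v in ds[apply_to]]
    return shift, scale


data = {"train": {"input": [1.0, 3.0]}, "test": {"input": [5.0]}}
shift, scale = setup_sketch(data)
print(shift, scale, data["test"]["input"])
```

Computing the factors on training data only, and then reusing them, keeps validation and test statistics from leaking into the normalization.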
- normalize_dataset(dataset)
Apply in-place normalization to the dataset.
- Parameters:
dataset (PinaDataset) – The dataset to be normalized.
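normalize_dataset modifies the dataset in place, so existing references to the stored data see the normalized values. A minimal sketch of that contract on a plain Python list follows; `normalize_in_place` and the dict layout are hypothetical. With PyTorch tensors, the analogous in-place operations would be `Tensor.sub_` and `Tensor.div_`.

```python
from typing import MutableSequence


def normalize_in_place(values: MutableSequence[float], shift: float, scale: float) -> None:
    """Overwrite each element with (x - shift) / scale; returns None like an in-place op."""
    for i, v in enumerate(values):
        values[i] = (v - shift) / scale


dataset = {"input": [2.0, 4.0, 6.0], "target": [1.0, 2.0, 3.0]}
inputs = dataset["input"]
normalize_in_place(dataset["input"], shift=4.0, scale=2.0)
assert inputs is dataset["input"]  # same list object, modified in place
print(dataset["input"])  # [-1.0, 0.0, 1.0]
```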