Normalizer callbacks

class NormalizerDataCallback(scale_fn=torch.std, shift_fn=torch.mean, stage='all', apply_to='input')

Bases: Callback

A Callback used to normalize the dataset inputs or targets according to user-provided scale and shift functions.

The transformation is applied as:

\[x_{\text{new}} = \frac{x - \text{shift}}{\text{scale}}\]
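The formula can be checked by hand on a small sample. The following is a minimal sketch, not the library's implementation: it uses plain Python lists and the standard `statistics` module as stand-ins for tensors, and the helper name `normalize` is hypothetical. `statistics.stdev` is chosen because, like `torch.std` with its default settings, it applies Bessel's correction.

```python
import statistics


def normalize(values, shift_fn=statistics.mean, scale_fn=statistics.stdev):
    """Return (x - shift) / scale for every x, plus the shift and scale used.

    Hypothetical helper for illustration only; the callback itself
    operates on dataset tensors rather than Python lists.
    """
    shift = shift_fn(values)
    scale = scale_fn(values)
    return [(x - shift) / scale for x in values], shift, scale


data = [2.0, 4.0, 6.0, 8.0]
normalized, shift, scale = normalize(data)

# With mean/std as shift/scale, the result has zero mean and unit
# sample standard deviation (up to floating-point error).
print(shift, statistics.mean(normalized), statistics.stdev(normalized))
```
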
Example:

>>> import torch
>>> NormalizerDataCallback()
>>> NormalizerDataCallback(
...     scale_fn=torch.std,
...     shift_fn=torch.mean,
...     stage="all",
...     apply_to="input",
... )

Initialization of the NormalizerDataCallback class.

Parameters:
  • scale_fn (Callable) – The function to compute the scaling factor. Default is torch.std.

  • shift_fn (Callable) – The function to compute the shifting factor. Default is torch.mean.

  • stage (str) – The stage in which normalization is applied. Accepted values are "train", "validate", "test", or "all". Default is "all".

  • apply_to (str) – Whether to normalize "input" or "target" data. Default is "input".
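Because scale_fn and shift_fn are arbitrary callables, other normalization schemes fit the same formula. As an illustration, min-max scaling is obtained by shifting by the minimum and scaling by the value range. The sketch below uses plain Python lists; `apply_normalization` and `value_range` are hypothetical names, and it is an assumption that callables passed to the callback must accept and return tensors instead.

```python
def apply_normalization(values, shift_fn, scale_fn):
    """Apply x_new = (x - shift) / scale with user-chosen callables."""
    shift = shift_fn(values)
    scale = scale_fn(values)
    return [(x - shift) / scale for x in values]


def value_range(values):
    """Hypothetical scale function: spread between largest and smallest value."""
    return max(values) - min(values)


samples = [10.0, 15.0, 20.0, 30.0]

# Shift by the minimum and scale by the range to map the data onto [0, 1].
rescaled = apply_normalization(samples, shift_fn=min, scale_fn=value_range)
print(rescaled)
```
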

Raises:
setup(trainer, pl_module, stage)

Apply normalization during setup.

Parameters:
Raises:
  • RuntimeError – If the training dataset is not available when computing normalization parameters.

  • NotImplementedError – If the dataset is graph-based.

Returns:

The result of the parent setup.

Return type:

Any
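The RuntimeError above suggests that the normalization parameters are computed from the training dataset and then reused for the other splits. The sketch below illustrates that pattern under this assumption. It is not the callback's code: `fit_and_apply`, the `splits` dictionary, and the returned keys are all hypothetical, and plain lists stand in for dataset tensors.

```python
import statistics


def fit_and_apply(splits, stage="all"):
    """Compute shift/scale on the training split, then normalize the
    splits selected by `stage` in place (hypothetical helper)."""
    if "train" not in splits:
        raise RuntimeError(
            "training split is required to compute normalization parameters"
        )

    train = splits["train"]
    shift = statistics.mean(train)
    scale = statistics.stdev(train)

    # "all" selects every split; otherwise only the named one is touched.
    selected = ("train", "validate", "test") if stage == "all" else (stage,)
    for name in selected:
        if name in splits:
            splits[name] = [(x - shift) / scale for x in splits[name]]

    return {"shift": shift, "scale": scale}


splits = {"train": [1.0, 2.0, 3.0], "validate": [2.0, 4.0]}
params = fit_and_apply(splits)
print(params, splits)
```

Reusing the training statistics for validation and test data keeps information from those splits out of the preprocessing step.
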

normalize_dataset(dataset)

Apply in-place normalization to the dataset.

Parameters:

dataset (PinaDataset) – The dataset to be normalized.

property normalizer

Get the dictionary of normalization parameters.

Returns:

The dictionary of normalization parameters.

Return type:

dict
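Stored normalization parameters are typically what is needed to map normalized values back to the original scale, by inverting the transformation: x = x_new * scale + shift. The sketch below shows the round trip. The layout of the dictionary returned by `normalizer` is not documented here, so the "shift" and "scale" keys, like the helper names, are assumptions made for illustration.

```python
def normalize(values, params):
    """Forward transform: (x - shift) / scale."""
    return [(x - params["shift"]) / params["scale"] for x in values]


def denormalize(values, params):
    """Inverse transform: x_new * scale + shift."""
    return [x * params["scale"] + params["shift"] for x in values]


# Hypothetical parameter dictionary; the real key names may differ.
params = {"shift": 5.0, "scale": 2.0}

original = [1.0, 5.0, 9.0]
restored = denormalize(normalize(original, params), params)
print(restored)
```
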