Normalizer callbacks
- class NormalizerDataCallback(scale_fn=torch.std, shift_fn=torch.mean, stage='all', apply_to='input')
Bases: Callback
A Callback used to normalize the dataset inputs or targets according to user-provided scale and shift functions.
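The transformation is applied as:
\[x_{\text{new}} = \frac{x - \text{shift}}{\text{scale}}\]
where \(\text{shift}\) and \(\text{scale}\) are the values returned by shift_fn and scale_fn, respectively.
- Example: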
>>> import torch
>>> NormalizerDataCallback()
>>> NormalizerDataCallback(
...     scale_fn=torch.std,
...     shift_fn=torch.mean,
...     stage="all",
...     apply_to="input",
... )
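As a quick sanity check on the formula, the sketch below applies it to a plain Python list. It uses the standard-library statistics module as a stand-in for torch.mean and torch.std (an assumption made so the example runs without PyTorch). statistics.stdev, like torch.std with its default settings, uses the unbiased n - 1 estimator, so the normalized values end up with mean 0 and sample standard deviation 1.

```python
import statistics


def normalize(values, scale_fn=statistics.stdev, shift_fn=statistics.mean):
    """Apply x_new = (x - shift) / scale to every element of `values`."""
    shift = shift_fn(values)  # plays the role of torch.mean
    scale = scale_fn(values)  # plays the role of torch.std (sample std)
    return [(x - shift) / scale for x in values]


data = [2.0, 4.0, 6.0, 8.0]
normalized = normalize(data)
print(normalized)
print(statistics.mean(normalized), statistics.stdev(normalized))
```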
Initialization of the NormalizerDataCallback class.
- Parameters:
scale_fn (Callable) – The function used to compute the scaling factor. Default is torch.std.
shift_fn (Callable) – The function used to compute the shifting factor. Default is torch.mean.
stage (str) – The stage in which normalization is applied. Accepted values are "train", "validate", "test", or "all". Default is "all".
apply_to (str) – Whether to normalize the "input" or the "target" data. Default is "input".
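Since scale_fn and shift_fn accept any callable, the same formula covers other normalization schemes. Min-max scaling, for instance, uses shift = min(x) and scale = max(x) - min(x), which maps the data onto [0, 1]. The sketch below is plain Python with a hypothetical value_range helper; with the actual callback the callables would operate on tensors, and the exact arguments the callback passes to them are not documented in this section, so check the source before relying on a particular signature.

```python
def normalize(values, scale_fn, shift_fn):
    """Apply x_new = (x - shift) / scale using user-supplied reductions."""
    shift = shift_fn(values)
    scale = scale_fn(values)
    return [(x - shift) / scale for x in values]


def value_range(values):
    """Hypothetical scale function: max(x) - min(x)."""
    return max(values) - min(values)


data = [3.0, 5.0, 11.0]
# Min-max scaling: shift by the minimum, divide by the range.
print(normalize(data, scale_fn=value_range, shift_fn=min))
```

A constant input would make value_range return 0, so a production scale function would need a guard against division by zero.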
- Raises:
ValueError – If scale_fn is not callable.
ValueError – If shift_fn is not callable.
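The documented ValueError checks amount to testing callable() on both functions at construction time. The class below is a hypothetical stand-in (not PINA's implementation) that mirrors only the behavior listed above; it does not validate stage or apply_to, since this section does not say whether the real callback does.

```python
class NormalizerSketch:
    """Hypothetical stand-in reproducing the documented constructor checks."""

    def __init__(self, scale_fn, shift_fn, stage="all", apply_to="input"):
        # Both factor functions must be callable, otherwise raise ValueError.
        if not callable(scale_fn):
            raise ValueError(f"scale_fn must be callable, got {type(scale_fn).__name__}")
        if not callable(shift_fn):
            raise ValueError(f"shift_fn must be callable, got {type(shift_fn).__name__}")
        self.scale_fn = scale_fn
        self.shift_fn = shift_fn
        self.stage = stage
        self.apply_to = apply_to


try:
    NormalizerSketch(scale_fn=1.0, shift_fn=min)  # 1.0 is not callable
except ValueError as err:
    print("rejected:", err)
```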
- setup(trainer, pl_module, stage)
Apply normalization during setup.
- Parameters:
trainer (Trainer) – The trainer running the solver.
pl_module (SolverInterface) – A SolverInterface instance.
stage (str) – The current stage.
- Raises:
RuntimeError – If the training dataset is not available when computing normalization parameters.
NotImplementedError – If the dataset is graph-based.
- Returns:
The result of the parent setup.
- Return type:
Any
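The Raises entries imply that the normalization parameters are computed from the training dataset, which is the usual way to avoid leaking statistics from validation or test data into the model. The sketch below illustrates that pattern on plain Python lists keyed by split name; the split names, the dictionary layout, and the rule that stage="all" normalizes every split are assumptions for illustration, not PINA's data structures.

```python
import statistics


def fit_and_apply(splits, stage="all", scale_fn=statistics.stdev, shift_fn=statistics.mean):
    """Compute shift/scale on the training split, then normalize the selected split(s).

    `splits` is a hypothetical dict mapping split names to lists of floats.
    """
    if "train" not in splits:
        # Mirrors the documented RuntimeError when training data is missing.
        raise RuntimeError("training dataset is required to compute normalization parameters")
    train = splits["train"]
    shift, scale = shift_fn(train), scale_fn(train)
    selected = set(splits) if stage == "all" else {stage}
    return {
        name: [(x - shift) / scale for x in values] if name in selected else list(values)
        for name, values in splits.items()
    }


splits = {"train": [1.0, 3.0, 5.0], "validate": [9.0], "test": [7.0]}
print(fit_and_apply(splits, stage="test"))
```

Here the test value is shifted and scaled with the training mean and standard deviation, not with statistics of the test split itself.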
- normalize_dataset(dataset)
Apply in-place normalization to the dataset.
- Parameters:
dataset (PinaDataset) – The dataset to be normalized.
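"In-place" means the dataset object itself is modified rather than a normalized copy being returned, so any other reference to the same data sees the new values. A minimal sketch of that behavior, using a hypothetical dict-of-lists in place of PinaDataset and precomputed shift and scale values:

```python
def normalize_in_place(dataset, key, shift, scale):
    """Overwrite dataset[key] element-wise with (x - shift) / scale; returns None."""
    values = dataset[key]
    for i, x in enumerate(values):
        values[i] = (x - shift) / scale


dataset = {"input": [2.0, 4.0, 6.0], "target": [1.0, 0.0, 1.0]}
alias = dataset["input"]  # a second reference to the same list
normalize_in_place(dataset, "input", shift=4.0, scale=2.0)
print(alias)  # [-1.0, 0.0, 1.0]: the alias sees the normalized values too
```

In this sketch, passing key="target" plays the role of apply_to="target"; the "target" entry is left untouched otherwise.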