Processing callbacks
- class MetricTracker(metrics_to_track=None)
Bases: Callback
Lightning Callback for Metric Tracking.
Tracks specified metrics during training.
- Parameters:
metrics_to_track (list[str], optional) – List of metrics to track. Defaults to train loss.
- setup(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune begins.
- Parameters:
pl_module (SolverInterface) – A SolverInterface instance.
stage (str) – Either ‘fit’, ‘test’ or ‘predict’.
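To illustrate the idea behind MetricTracker, here is a minimal, library-free sketch of collecting a chosen set of metrics over epochs. It is not PINA's implementation: SimpleMetricTracker, record, and the train_loss key are hypothetical names picked for the example, and each epoch's logged values are assumed to arrive as a plain dictionary.

```python
from collections import defaultdict


class SimpleMetricTracker:
    """Hypothetical stand-in that mimics tracking a chosen set of metrics."""

    def __init__(self, metrics_to_track=None):
        # Assumption: fall back to a single training-loss key when none is given.
        self.metrics_to_track = list(metrics_to_track or ["train_loss"])
        self.history = defaultdict(list)

    def record(self, logged_metrics):
        """Store the tracked metrics found in one epoch's logged values."""
        for name in self.metrics_to_track:
            if name in logged_metrics:
                self.history[name].append(float(logged_metrics[name]))


# Two simulated epochs; the "lr" entry is ignored because it is not tracked.
tracker = SimpleMetricTracker(["train_loss", "val_loss"])
tracker.record({"train_loss": 0.9, "val_loss": 1.1, "lr": 1e-3})
tracker.record({"train_loss": 0.5, "val_loss": 0.7, "lr": 1e-3})
```

In a real Lightning callback the recording step would typically run inside an epoch-end hook and read the values from the trainer rather than from a hand-built dictionary.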
- class PINAProgressBar(metrics='val', **kwargs)
Bases: TQDMProgressBar
PINA implementation of a Lightning Callback for enriching the progress bar.
This class enables the display of only relevant metrics during training.
- Parameters:
metrics (str | list(str) | tuple(str)) – Logged metrics to be shown during training. Must be a subset of the condition keys defined in pina.condition.Condition.
- Keyword Arguments:
The additional keyword arguments configure the progress bar and can be chosen from the pytorch-lightning TQDMProgressBar API.
Example
>>> pbar = PINAProgressBar(['mean'])
>>> # ... Perform training ...
>>> trainer = Trainer(solver, callbacks=[pbar])
- get_metrics(trainer, pl_module)
Combine progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. Override this method to customize the items shown in the progress bar. The progress bar metrics are sorted according to metrics.
Here is an example of how to override the defaults:
def get_metrics(self, trainer, model):
    # don't show the version number
    items = super().get_metrics(trainer, model)
    items.pop("v_num", None)
    return items
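The filtering and ordering behaviour described for get_metrics can be sketched without any library dependency. The helper below is a hypothetical illustration, not PINA's code: select_progress_bar_items and the metric keys are invented for the example, and exact key matching is an assumption (the real callback may match keys differently).

```python
def select_progress_bar_items(items, metrics):
    """Keep only the requested entries, ordered as they appear in `metrics`."""
    # Accept a single metric name as well as a list or tuple of names.
    if isinstance(metrics, str):
        metrics = [metrics]
    # Dictionaries preserve insertion order, so the result follows `metrics`.
    return {key: items[key] for key in metrics if key in items}


# Simulated progress bar items, including the version number entry.
items = {"v_num": 0, "val_loss": 0.7, "train_loss": 0.5}
shown = select_progress_bar_items(items, ["train_loss", "val_loss"])
```

Building a new dictionary, rather than deleting keys in place, leaves the original items untouched and makes the display order explicit.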