Processing callbacks
- class MetricTracker(metrics_to_track=None)
Bases: Callback
Lightning callback for metric tracking.
Tracks the specified metrics during training.
- Parameters:
metrics_to_track (list[str], optional) – List of metrics to track. Defaults to the training loss.
- setup(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune begins.
- Parameters:
trainer (Trainer) – The Lightning Trainer instance.
pl_module (SolverInterface) – A SolverInterface instance.
stage (str) – One of ‘fit’, ‘validate’, ‘test’ or ‘predict’.
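The tracking behaviour can be illustrated with a minimal pure-Python sketch. It assumes that the tracked values are read from Lightning's `trainer.callback_metrics` dictionary at the end of each training epoch and appended to a per-metric history, and that `setup` resets that history when a fit run starts.

The class name `SimpleMetricTracker`, its `history` attribute, the metric key `"train_loss"` and the stand-in trainer built with `SimpleNamespace` are all hypothetical. This is not PINA's implementation.

```python
from types import SimpleNamespace


class SimpleMetricTracker:
    """Hypothetical sketch of a metric-tracking callback (not PINA's code)."""

    def __init__(self, metrics_to_track=None):
        # Follow the documented default of tracking only the training loss.
        # The key name "train_loss" is an assumption made for this sketch.
        if metrics_to_track is None:
            metrics_to_track = ["train_loss"]
        self.metrics_to_track = list(metrics_to_track)
        self.history = {}

    def setup(self, trainer, pl_module, stage):
        # Reset the history only when a fit run begins.
        if stage == "fit":
            self.history = {name: [] for name in self.metrics_to_track}

    def on_train_epoch_end(self, trainer, pl_module):
        # Record the latest value of each tracked metric, skipping missing keys.
        for name in self.metrics_to_track:
            if name in trainer.callback_metrics:
                value = float(trainer.callback_metrics[name])
                self.history.setdefault(name, []).append(value)


# Simulate two epochs with a stand-in trainer that exposes callback_metrics.
tracker = SimpleMetricTracker()
trainer = SimpleNamespace(callback_metrics={})
tracker.setup(trainer, pl_module=None, stage="fit")
for loss in (0.5, 0.25):
    trainer.callback_metrics = {"train_loss": loss, "val_loss": 2 * loss}
    tracker.on_train_epoch_end(trainer, pl_module=None)

print(tracker.history)  # {'train_loss': [0.5, 0.25]}
```

Only the requested keys are recorded, so `val_loss` is ignored unless it is listed in `metrics_to_track`.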
- class PINAProgressBar(metrics='val', **kwargs)
Bases: TQDMProgressBar
PINA implementation of a Lightning callback that enriches the progress bar.
This class displays only the relevant metrics during training.
- Parameters:
metrics (str | list(str) | tuple(str)) – Logged metrics to be shown during training. Must be a subset of the condition keys defined in pina.condition.Condition.
- Keyword Arguments:
Additional keyword arguments configure the progress bar and can be any of the arguments accepted by the PyTorch Lightning TQDMProgressBar API.
Example
```python
>>> pbar = PINAProgressBar(['mean'])
>>> trainer = Trainer(solver, callbacks=[pbar])
>>> # ... perform training ...
```
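The metrics argument accepts either a single string or a list or tuple of strings, and its entries must be condition keys. The sketch below shows one way to normalize and validate such an argument, applying the subset rule exactly as stated above.

The helper `resolve_metrics` and the sample condition names are hypothetical. PINA may perform this check differently, or at a different point, for example once the problem's conditions are known.

```python
def resolve_metrics(metrics, condition_names):
    """Hypothetical helper: normalize `metrics` to a list and validate it."""
    # A single string becomes a one-element list; lists and tuples are copied.
    if isinstance(metrics, str):
        names = [metrics]
    elif isinstance(metrics, (list, tuple)):
        names = list(metrics)
    else:
        raise TypeError("metrics must be a str, list or tuple of str")

    # Every requested metric must be one of the known condition keys.
    unknown = [name for name in names if name not in condition_names]
    if unknown:
        raise KeyError(f"Unknown metrics: {unknown}")
    return names


conditions = {"gamma1", "gamma2", "D"}  # hypothetical condition keys
print(resolve_metrics("D", conditions))              # ['D']
print(resolve_metrics(("gamma1", "D"), conditions))  # ['gamma1', 'D']
```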
- get_metrics(trainer, pl_module)
Combines the progress bar metrics collected from the trainer with the standard metrics from get_standard_metrics. Override this method to customize the items shown in the progress bar. The progress bar metrics are sorted according to the metrics argument.
Here is an example of how to override the defaults:
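```python
# Hypothetical subclass name; assumes PINAProgressBar is already imported.
class MyProgressBar(PINAProgressBar):
    def get_metrics(self, trainer, model):
        # don't show the version number
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)
        return items
```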
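The combine-and-sort step described above can be sketched with plain dictionaries. The sketch makes three assumptions:

- the progress bar metrics are filtered to the requested keys;
- they are kept in the order given by the metrics argument;
- they are merged after the standard metrics, such as v_num.

The function name `combine_metrics` and the sample values are hypothetical. The real method obtains these dictionaries from the Lightning trainer rather than receiving them as arguments.

```python
def combine_metrics(standard_metrics, pbar_metrics, metrics):
    """Hypothetical sketch: keep only requested progress bar metrics, in order."""
    # Select the requested keys in the order given by `metrics`,
    # skipping keys that have not been logged yet.
    selected = {key: pbar_metrics[key] for key in metrics if key in pbar_metrics}
    # Standard metrics come first; selected metrics override duplicate keys.
    return {**standard_metrics, **selected}


standard = {"v_num": 3}
logged = {"train_loss": 0.12, "gamma1_loss": 0.05, "D_loss": 0.07}
items = combine_metrics(standard, logged, ["D_loss", "gamma1_loss"])
print(list(items))  # ['v_num', 'D_loss', 'gamma1_loss']

# Drop the version number, mirroring the override example.
items.pop("v_num", None)
print(items)  # {'D_loss': 0.07, 'gamma1_loss': 0.05}
```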