Processing callbacks

class MetricTracker(metrics_to_track=None)

Bases: Callback

Lightning callback for metric tracking.

Tracks the specified metrics during training.

Parameters:

metrics_to_track (list[str], optional) – List of metrics to track. Defaults to the training loss.

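setup(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune begins.

Parameters:
  • trainer (pytorch_lightning.Trainer) – The trainer object managing the training process.

  • pl_module – The model being trained.

  • stage (str) – The stage being set up: "fit", "validate", "test", or "predict".

on_train_epoch_end(trainer, pl_module)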

Collect and track metrics at the end of each training epoch.

Parameters:
  • trainer (pytorch_lightning.Trainer) – The trainer object managing the training process.

  • pl_module – The model being trained (not used here).
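
A minimal, framework-free sketch of the per-epoch collection step, assuming the tracker snapshots the requested entries of the trainer's logged metrics (PyTorch Lightning exposes these as trainer.logged_metrics). The class SimpleMetricTracker, the default key "train_loss", and the stand-in trainer built with SimpleNamespace are hypothetical; plain floats replace tensors.

```python
from types import SimpleNamespace


class SimpleMetricTracker:
    """Hypothetical stand-in for MetricTracker's collection logic."""

    def __init__(self, metrics_to_track=None):
        # Assumption: fall back to a key named "train_loss" when nothing is requested.
        self.metrics_to_track = metrics_to_track or ["train_loss"]
        self._collection = []

    def on_train_epoch_end(self, trainer, pl_module):
        # Snapshot only the tracked metrics logged this epoch; building a new
        # dict keeps later epochs from overwriting earlier snapshots.
        logged = trainer.logged_metrics
        snapshot = {k: logged[k] for k in self.metrics_to_track if k in logged}
        self._collection.append(snapshot)


tracker = SimpleMetricTracker()
for loss in [0.9, 0.5, 0.2]:
    # Stand-in trainer exposing only the attribute the sketch reads.
    trainer = SimpleNamespace(logged_metrics={"train_loss": loss, "lr": 1e-3})
    tracker.on_train_epoch_end(trainer, pl_module=None)

print(tracker._collection)
# [{'train_loss': 0.9}, {'train_loss': 0.5}, {'train_loss': 0.2}]
```

Untracked entries such as "lr" are ignored, so each snapshot holds only the requested metrics.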

property metrics

Aggregate collected metrics over all epochs.

Returns:

A dictionary containing aggregated metric values.

Return type:

dict
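
One plausible way to aggregate the per-epoch snapshots into a single dictionary, sketched in plain Python: keep only the keys logged in every epoch and gather each key's values into a list in epoch order. The function name aggregate_metrics is hypothetical, and the actual property may stack tensors (for example with torch.stack) rather than build lists.

```python
def aggregate_metrics(collection):
    """Hypothetical aggregation of per-epoch metric dicts.

    collection: list of dicts, one per epoch, mapping metric name -> value.
    Returns a dict mapping each metric logged in every epoch to the list
    of its values, in epoch order.
    """
    if not collection:
        return {}
    # Keys present in all epochs; sorted for a deterministic result.
    common = set(collection[0]).intersection(*collection[1:])
    return {key: [epoch[key] for epoch in collection] for key in sorted(common)}


history = [
    {"train_loss": 0.9, "val_loss": 1.1},
    {"train_loss": 0.5},  # val_loss missing this epoch
    {"train_loss": 0.2, "val_loss": 0.4},
]
print(aggregate_metrics(history))
# {'train_loss': [0.9, 0.5, 0.2]}
```

Restricting to common keys avoids misaligned series when a metric is not logged in every epoch.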

class PINAProgressBar(metrics='val', **kwargs)

Bases: TQDMProgressBar

PINA implementation of a Lightning callback that enriches the progress bar.

This class restricts the progress bar to the relevant metrics during training.

Parameters:

metrics (str | list(str) | tuple(str)) – Logged metrics to show during training. Must be a subset of the condition keys defined in pina.condition.Condition.

Keyword Arguments:

Additional keyword arguments configure the progress bar and can be any of those accepted by the PyTorch Lightning TQDMProgressBar API (for example, refresh_rate).
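
Because metrics may be a single string, a list, or a tuple, the constructor presumably normalizes it to a list before use. A minimal sketch of that normalization, assuming this behavior; the helper name normalize_metrics is hypothetical.

```python
def normalize_metrics(metrics):
    """Hypothetical helper: turn str | list[str] | tuple[str] into list[str]."""
    if isinstance(metrics, str):
        # A single metric name becomes a one-element list.
        return [metrics]
    if isinstance(metrics, (list, tuple)) and all(isinstance(m, str) for m in metrics):
        return list(metrics)
    raise TypeError("metrics must be a str, or a list or tuple of str")


print(normalize_metrics("val"))            # ['val']
print(normalize_metrics(("mean", "val")))  # ['mean', 'val']
```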

Example

>>> pbar = PINAProgressBar(['mean'])
>>> trainer = Trainer(solver, callbacks=[pbar])
>>> # ... Perform training ...

get_metrics(trainer, pl_module)

Combine the progress bar metrics collected from the trainer with the standard metrics from get_standard_metrics. Override this method to customize the items shown in the progress bar. The progress bar metrics are sorted according to the metrics argument passed to the constructor.

Here is an example of how to override the defaults:

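from pytorch_lightning.callbacks import TQDMProgressBar

class NoVersionProgressBar(TQDMProgressBar):
    def get_metrics(self, trainer, model):
        # don't show the version number
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)
        return items

Returns: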

Dictionary with the items to be displayed in the progress bar.

Return type:

dict
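
A plain-Python sketch of the combination described for get_metrics, assuming the progress-bar metrics are filtered to the requested names and placed, in that order, after the standard entries. The function combine_metrics and its inputs are hypothetical stand-ins for the standard metrics (such as v_num) and trainer.progress_bar_metrics; requested names that were not logged are skipped here rather than raising.

```python
def combine_metrics(standard_metrics, pbar_metrics, requested):
    """Hypothetical merge of standard and progress-bar metrics."""
    # Keep only the requested progress-bar metrics, in the requested order.
    selected = {k: pbar_metrics[k] for k in requested if k in pbar_metrics}
    # Dicts preserve insertion order, so standard entries come first;
    # on a key clash the progress-bar value wins.
    return {**standard_metrics, **selected}


standard = {"v_num": 0}
logged = {"a_loss": 0.31, "mean_loss": 0.27, "b_loss": 0.22}
print(combine_metrics(standard, logged, ["mean_loss", "a_loss"]))
# {'v_num': 0, 'mean_loss': 0.27, 'a_loss': 0.31}
```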

setup(trainer, pl_module, stage)

Check that the initialized metrics are available and correctly logged.

Parameters:
  • trainer (pytorch_lightning.Trainer) – The trainer object managing the training process.

  • pl_module – Placeholder argument.
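
The check performed in setup can be sketched as a subset test of the requested metrics against the problem's condition names. This is a minimal sketch under assumptions: validate_metrics is a hypothetical helper, the condition names ("gamma1", "gamma2", "D") are made up, they are passed in directly instead of being read from the solver's problem, and raising KeyError is an illustrative choice.

```python
def validate_metrics(requested, condition_names):
    """Hypothetical check: every requested metric must name a condition."""
    missing = [m for m in requested if m not in condition_names]
    if missing:
        # Report every unknown name at once rather than stopping at the first.
        raise KeyError(f"Metrics not found among the conditions: {missing}")
    return True


conditions = {"gamma1", "gamma2", "D"}  # hypothetical condition names
print(validate_metrics(["gamma1", "D"], conditions))  # True
```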