Processing callbacks

class MetricTracker

Bases: Callback

PINA implementation of a Lightning callback for metric tracking.

This class provides functionality to track relevant metrics during the training process.

Variables:

_collection – A list that stores the collected metrics after each training epoch.

The constructor takes no arguments: the trainer object is passed to on_train_epoch_end, and the dictionary of aggregated metric values is returned by the metrics property.

Example

>>> tracker = MetricTracker()
>>> # ... Perform training ...
>>> metrics = tracker.metrics

on_train_epoch_end(trainer, pl_module)

Collect and track metrics at the end of each training epoch.

Parameters:
  • trainer (pytorch_lightning.Trainer) – The trainer object managing the training process.

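  • pl_module (pytorch_lightning.LightningModule) – The module being trained; a placeholder argument required by the Lightning hook signature.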
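The collection step can be pictured with the plain-Python sketch below. It assumes that the metrics of interest are those in the trainer's logged_metrics dictionary (a real Lightning Trainer attribute) and snapshots them with copy.deepcopy so that later epochs cannot alter earlier entries. The class name EpochMetricCollector and the SimpleNamespace stand-in trainer are hypothetical; this is not PINA's implementation.

```python
import copy
from types import SimpleNamespace


class EpochMetricCollector:
    """Hypothetical stand-in for the collection logic of a metric tracker."""

    def __init__(self):
        # One snapshot of the logged metrics per completed epoch.
        self._collection = []

    def on_train_epoch_end(self, trainer, pl_module):
        # pl_module is accepted only to match the Lightning hook signature.
        # Deep-copy so that later updates to trainer.logged_metrics do not
        # alter the snapshots already stored.
        self._collection.append(copy.deepcopy(trainer.logged_metrics))


# Fake trainer (assumption: only logged_metrics is read by the hook).
trainer = SimpleNamespace(logged_metrics={"mean_loss": 0.5})
collector = EpochMetricCollector()
collector.on_train_epoch_end(trainer, pl_module=None)
trainer.logged_metrics["mean_loss"] = 0.25  # simulate the next epoch
collector.on_train_epoch_end(trainer, pl_module=None)
print(collector._collection)  # [{'mean_loss': 0.5}, {'mean_loss': 0.25}]
```

Without the deep copy, both entries would reference the same dictionary and would show only the latest values.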

property metrics

Aggregates the metrics collected during training.

Returns:

A dictionary containing aggregated metric values.

Return type:

dict
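A plain-Python sketch of one way to aggregate the per-epoch snapshots, assuming that only metric names logged in every epoch are kept and that each name maps to its per-epoch history. The function aggregate_metrics and the metric names are hypothetical; the real property may differ, for example by returning tensors instead of lists.

```python
def aggregate_metrics(collection):
    """Merge per-epoch metric dicts into {metric_name: [value per epoch]}.

    Only keys present in every epoch are kept, so each history has exactly
    one value per collected epoch.
    """
    if not collection:
        return {}
    # Metric names logged in every epoch.
    common_keys = set.intersection(*(set(epoch) for epoch in collection))
    return {key: [epoch[key] for epoch in collection] for key in sorted(common_keys)}


# "D_loss" stands for a hypothetical condition named "D".
history = [
    {"mean_loss": 0.9, "D_loss": 0.5},
    {"mean_loss": 0.4, "D_loss": 0.2, "lr": 1e-3},  # "lr" is not in every epoch
]
print(aggregate_metrics(history))  # {'D_loss': [0.5, 0.2], 'mean_loss': [0.9, 0.4]}
```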

class PINAProgressBar(metrics='mean', **kwargs)

Bases: TQDMProgressBar

PINA implementation of a Lightning callback that customizes the progress bar.

This class provides functionality to display only relevant metrics during the training process.

Parameters:

metrics (str | list(str) | tuple(str)) – The logged metrics to display during training. They should be a subset of the condition keys defined in pina.condition.Condition; the default, 'mean', is also accepted.

Keyword Arguments:

Additional keyword arguments configure the progress bar and can be any of those accepted by the PyTorch Lightning TQDMProgressBar API (e.g., refresh_rate).
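Because metrics may be a single string or a list/tuple of strings, a constructor of this kind typically normalizes it to a list first. The helper normalize_metrics below is a hypothetical sketch of that step, not PINA's code; the name "gamma1" is an illustrative condition name.

```python
def normalize_metrics(metrics="mean"):
    """Return the metrics argument as a list of strings."""
    # A lone string such as "mean" becomes a one-element list.
    if isinstance(metrics, str):
        metrics = [metrics]
    metrics = list(metrics)
    # Every entry must be a string (a metric or condition name).
    if not all(isinstance(m, str) for m in metrics):
        raise TypeError("metrics must be a str or a list/tuple of str")
    return metrics


print(normalize_metrics())                    # ['mean']
print(normalize_metrics(("mean", "gamma1")))  # ['mean', 'gamma1']
```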

Example

>>> pbar = PINAProgressBar(['mean'])
>>> trainer = Trainer(solver, callbacks=[pbar])
>>> # ... Perform training ...

get_metrics(trainer, pl_module)

Combines the progress bar metrics collected from the trainer with the standard metrics from get_standard_metrics. Override this method to customize the items displayed in the progress bar. The progress bar metrics are selected and ordered according to the metrics argument given at initialization.

Here is an example of how to override the defaults:

class MyProgressBar(PINAProgressBar):
    def get_metrics(self, trainer, pl_module):
        # don't show the version number
        items = super().get_metrics(trainer, pl_module)
        items.pop("v_num", None)
        return items

Returns:

Dictionary with the items to be displayed in the progress bar.

Return type:

dict
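The selection and ordering described above can be sketched with plain dictionaries. Here standard plays the role of the output of Lightning's get_standard_metrics (for instance the v_num entry) and pbar the role of the trainer's progress_bar_metrics. The function select_pbar_metrics and the "_loss"-suffixed metric names are assumptions for illustration only.

```python
def select_pbar_metrics(standard, pbar, wanted):
    """Keep only the wanted progress-bar metrics, in the order given by
    `wanted`, and place them after the standard metrics."""
    # Missing keys are skipped here; a stricter version could raise instead.
    selected = {key: pbar[key] for key in wanted if key in pbar}
    # On a name clash, the selected progress-bar value wins.
    return {**standard, **selected}


standard = {"v_num": 0}
pbar = {"gamma1_loss": 0.3, "mean_loss": 0.1, "D_loss": 0.2}
items = select_pbar_metrics(standard, pbar, ["mean_loss", "D_loss"])
print(items)  # {'v_num': 0, 'mean_loss': 0.1, 'D_loss': 0.2}
```

Since Python dictionaries preserve insertion order, the order of wanted directly controls the display order.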

on_fit_start(trainer, pl_module)

Checks that the metrics passed at initialization are available, i.e., that they are correctly logged.

Parameters:
  • trainer (pytorch_lightning.Trainer) – The trainer object managing the training process.

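  • pl_module (pytorch_lightning.LightningModule) – The module being trained; a placeholder argument required by the Lightning hook signature.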
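A minimal sketch of such an availability check as a plain function, assuming that the valid names are the problem's condition names plus the default 'mean'. The function check_metrics and the way the condition names are passed in are hypothetical; in a real run the names would come from the problem attached to the solver.

```python
def check_metrics(metrics, condition_names):
    """Raise KeyError if a requested metric is neither a condition name
    nor the default 'mean'."""
    allowed = set(condition_names) | {"mean"}
    for name in metrics:
        if name not in allowed:
            raise KeyError(f"Metric '{name}' is not logged by any condition")


check_metrics(["mean", "gamma1"], ["gamma1", "D"])  # passes silently
try:
    check_metrics(["typo"], ["gamma1", "D"])
except KeyError as err:
    print("rejected:", err)
```

Failing at fit start surfaces a misspelled metric name before any training time is spent.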