Source code for pina.callback.processing.pina_progress_bar

"""Module for the Processing Callbacks."""

from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import (
    get_standard_metrics,
)
from pina.utils import check_consistency


[docs] class PINAProgressBar(TQDMProgressBar): """ PINA Implementation of a Lightning Callback for enriching the progress bar. """ BAR_FORMAT = ( "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, " "{rate_noinv_fmt}{postfix}]" ) def __init__(self, metrics="val", **kwargs): """ This class enables the display of only relevant metrics during training. :param metrics: Logged metrics to be shown during the training. Must be a subset of the conditions keys defined in :obj:`pina.condition.Condition`. :type metrics: str | list(str) | tuple(str) :Keyword Arguments: The additional keyword arguments specify the progress bar and can be choosen from the `pytorch-lightning TQDMProgressBar API <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_ Example: >>> pbar = PINAProgressBar(['mean']) >>> # ... Perform training ... >>> trainer = Trainer(solver, callbacks=[pbar]) """ super().__init__(**kwargs) # check consistency if not isinstance(metrics, (list, tuple)): metrics = [metrics] check_consistency(metrics, str) self._sorted_metrics = metrics
[docs] def get_metrics(self, trainer, pl_module): r"""Combine progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. Override this method to customize the items shown in the progress bar. The progress bar metrics are sorted according to ``metrics``. Here is an example of how to override the defaults: .. code-block:: python def get_metrics(self, trainer, model): # don't show the version number items = super().get_metrics(trainer, model) items.pop("v_num", None) return items :return: Dictionary with the items to be displayed in the progress bar. :rtype: tuple(dict) """ standard_metrics = get_standard_metrics(trainer) pbar_metrics = trainer.progress_bar_metrics if pbar_metrics: pbar_metrics = { key: pbar_metrics[key] for key in pbar_metrics if key in self._sorted_metrics } return {**standard_metrics, **pbar_metrics}
[docs] def setup(self, trainer, pl_module, stage): """ Check that the initialized metrics are available and correctly logged. :param trainer: The trainer object managing the training process. :type trainer: pytorch_lightning.Trainer :param pl_module: Placeholder argument. """ # Check if all keys in sort_keys are present in the dictionary for key in self._sorted_metrics: if ( key not in trainer.solver.problem.conditions.keys() and key != "train" and key != "val" ): raise KeyError(f"Key '{key}' is not present in the dictionary") # add the loss pedix if trainer.batch_size is not None: pedix = "_loss_epoch" else: pedix = "_loss" self._sorted_metrics = [ metric + pedix for metric in self._sorted_metrics ] return super().setup(trainer, pl_module, stage)