Source code for pina._src.callback.processing.pina_progress_bar

"""Module for the Processing Callbacks."""

from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import (
    get_standard_metrics,
)
from pina._src.core.utils import check_consistency


[docs] class PINAProgressBar(TQDMProgressBar): """ Custom progress bar callback for PINA training workflows. This callback extends the default Lightning progress bar by filtering the displayed metrics. Metrics can refer either to condition-specific losses, identified by the names assigned to the problem conditions, or to global losses. Global losses are selected using ``"train"``, ``"val"``, or ``"test"``, and are internally expanded to the corresponding logged loss metrics. """ GLOBAL_LOSS_KEYS = ("train", "val", "test") BAR_FORMAT = ( "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, " "{rate_noinv_fmt}{postfix}]" ) def __init__(self, metrics="val", **kwargs): """ Initialization of the :class:`PINAProgressBar`. :param metrics: The names of the metrics to be shown in the progress bar. Each entry can be either a key of a condition defined in the problem or one of the global loss keys: ``"train"``, ``"val"``, or ``"test"``. These global keys are internally expanded to the corresponding logged loss names. Default is ``"val"``. :type metrics: str | list(str) | tuple(str) :param dict kwargs: Additional keyword arguments passed to :class:`lightning.pytorch.callbacks.TQDMProgressBar`. :raises TypeError: If ``metrics`` contains non-string elements. """ super().__init__(**kwargs) # Check consistency check_consistency(metrics, str) # Convert to list if a single string is provided if isinstance(metrics, str): metrics = [metrics] # Store the sorted metrics for later use in get_metrics self._sorted_metrics = sorted(metrics)
[docs] def get_metrics(self, trainer, __): """ Retrieve and filter metrics to be displayed in the progress bar. This method combines standard Lightning metrics with user-selected progress bar metrics, retaining only the metrics specified at initialization. :param Trainer trainer: The trainer managing the training loop. :param __: Placeholder argument, not used. :return: Dictionary containing the metrics to display. :rtype: dict .. note:: This method overrides the default Lightning behavior. It can be further customized by subclassing. """ # Retrieve standard metrics and user-selected progress bar metrics standard_metrics = get_standard_metrics(trainer) progress_bar_metrics = trainer.progress_bar_metrics # Filter progress bar metrics to include only specified keys if progress_bar_metrics: progress_bar_metrics = { key: progress_bar_metrics[key] for key in progress_bar_metrics if key in self._sorted_metrics } return {**standard_metrics, **progress_bar_metrics}
[docs] def setup(self, trainer, pl_module, stage): """ Configure the metrics to track before execution starts. The requested metrics must be either names assigned to problem conditions or global loss keys. The accepted global loss keys are ``"train"``, ``"val"``, and ``"test"``. :param Trainer trainer: The trainer instance managing the execution. :param BaseSolver pl_module: The solver module being executed. :param str stage: Current execution stage. :raises KeyError: If a metric key is neither a condition key nor one of ``"train"``, ``"val"``, or ``"test"``. """ # Get the condition keys from the problem condition_keys = trainer.solver.problem.conditions.keys() for key in self._sorted_metrics: if key not in condition_keys and key not in self.GLOBAL_LOSS_KEYS: raise KeyError( f"Key '{key}' is not a valid metric. It must be either a " f"problem condition key or one of {self.GLOBAL_LOSS_KEYS}." ) # Add the appropriate suffix to the metric names based on batch size suffix = "_loss_epoch" if trainer.batch_size is not None else "_loss" self._sorted_metrics = [ metric + suffix for metric in self._sorted_metrics ] return super().setup(trainer, pl_module, stage)