Source code for pina._src.callback.processing.metric_tracker

"""Module for the Metric Tracker."""

import copy
import torch
from lightning.pytorch.callbacks import Callback
from pina._src.core.utils import check_consistency


class MetricTracker(Callback):
    """
    Callback for collecting selected metrics logged during training.

    At the end of every training epoch after the first one, the values of
    the tracked metrics are read from ``trainer.logged_metrics`` and stored.
    The stored history is exposed through the :attr:`metrics` property.
    """

    def __init__(self, metrics_to_track=None):
        """
        Initialization of the :class:`MetricTracker` class.

        :param metrics_to_track: The names of the metrics to collect.
            If ``None``, defaults to ``["train_loss"]`` when no batch size is
            available, otherwise to
            ``["train_loss_step", "train_loss_epoch"]``. Default is ``None``.
        :type metrics_to_track: str | list[str]
        :raises ValueError: If any of the provided metric names are not
            strings.
        """
        super().__init__()

        # Check consistency
        if metrics_to_track is not None:
            check_consistency(metrics_to_track, str)

        # Convert to list if a single string is provided
        if isinstance(metrics_to_track, str):
            metrics_to_track = [metrics_to_track]

        # Keep the names exactly as requested by the user. ``setup`` derives
        # the tracked names from this copy, so running it more than once
        # never expands already-expanded names.
        self._requested_metrics = (
            None if metrics_to_track is None else list(metrics_to_track)
        )

        # Initialize the collection list and store the metrics to track
        self.metrics_to_track = metrics_to_track
        self._collection = []

    def setup(self, trainer, pl_module, stage):
        """
        Configure the metrics to track before execution starts.

        When a batch size is provided (i.e. ``trainer.batch_size`` is not
        ``None``), metric names are expanded to match Lightning's logging
        convention: for each metric ``m``, both ``m_step`` and ``m_epoch``
        are tracked. For example, ``"train_loss"`` becomes
        ``["train_loss_step", "train_loss_epoch"]``.

        The tracked names are always rebuilt from the names given at
        construction time, so this hook is idempotent: it can run once per
        stage (fit, test, ...) without accumulating suffixes.

        :param Trainer trainer: The trainer instance managing the execution.
        :param BaseSolver pl_module: The solver module being executed.
        :param str stage: Current execution stage.
        """
        # Set default metrics to train_loss if none were requested
        base_metrics = self._requested_metrics
        if base_metrics is None:
            base_metrics = ["train_loss"]

        # If a batch size is provided, expand metric names to match convention
        if trainer.batch_size is not None:
            self.metrics_to_track = [
                f"{metric}_{suffix}"
                for metric in base_metrics
                for suffix in ("step", "epoch")
            ]
        else:
            self.metrics_to_track = list(base_metrics)

        return super().setup(trainer, pl_module, stage)

    def on_train_epoch_end(self, trainer, __):
        """
        Store the selected logged metrics at the end of each training epoch.

        :param Trainer trainer: The trainer instance managing the execution.
        :param __: Placeholder argument, not used.
        """
        # Only collect metrics after the first epoch to ensure they are logged
        if trainer.current_epoch > 0:
            # Collect the metrics that are being tracked
            tracked_metrics = {
                k: v
                for k, v in trainer.logged_metrics.items()
                if k in self.metrics_to_track
            }
            # Deep copy so later updates to the logged values cannot alter
            # what has already been stored
            self._collection.append(copy.deepcopy(tracked_metrics))

    @property
    def metrics(self):
        """
        Return the collected metrics stacked over the tracked epochs.

        Only metrics present in every collected epoch are returned, so that
        each returned tensor has one entry per collected epoch.

        :return: The dictionary mapping each metric name to a tensor
            containing its values across epochs. Returns an empty dictionary
            if no metrics have been collected.
        :rtype: dict[str, torch.Tensor]
        """
        if not self._collection:
            return {}

        # Identify the common keys across all collected metric dictionaries
        common_keys = set(self._collection[0]).intersection(
            *self._collection[1:]
        )

        # Follow the order of ``metrics_to_track`` rather than set order, so
        # the key order of the result is the same on every run
        return {
            k: torch.stack([dic[k] for dic in self._collection])
            for k in self.metrics_to_track
            if k in common_keys
        }