# Source code for pina.callback.processing.metric_tracker
"""Module for the Metric Tracker."""
import copy
import torch
from lightning.pytorch.callbacks import Callback
[docs]
class MetricTracker(Callback):
    """
    Lightning Callback for Metric Tracking.

    At the end of every training epoch after the first one, a copy of the
    selected entries of ``trainer.logged_metrics`` is stored. The stored
    values are exposed, stacked per metric, through :attr:`metrics`.
    """

    def __init__(self, metrics_to_track=None):
        """
        Tracks specified metrics during training.

        :param metrics_to_track: Names of the metrics to track. A single
            name may also be given as a plain string. If ``None``, the
            default is chosen in :meth:`setup`: ``"train_loss"`` when the
            trainer has no batch size, ``"train_loss_epoch"`` otherwise.
        :type metrics_to_track: list[str] | str, optional
        """
        super().__init__()
        # One dict of tracked metrics per collected epoch.
        self._collection = []
        # A bare string would turn the ``in`` membership checks used by this
        # class into substring tests (e.g. ``"loss" in "train_loss"``), so
        # wrap a single name in a list.
        if isinstance(metrics_to_track, str):
            metrics_to_track = [metrics_to_track]
        # ``None`` is kept as-is: the default depends on the trainer and is
        # therefore resolved in ``setup``.
        self.metrics_to_track = metrics_to_track

    def setup(self, trainer, pl_module, stage):
        """
        Called when fit, validate, test, predict, or tune begins.

        If no metrics were specified at construction time, the default is
        selected here based on ``trainer.batch_size``.

        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
        :param SolverInterface pl_module: A
            :class:`~pina.solver.solver.SolverInterface` instance.
        :param str stage: Either 'fit', 'test' or 'predict'.
        :return: The result of the parent class ``setup`` hook.
        """
        if self.metrics_to_track is None and trainer.batch_size is None:
            self.metrics_to_track = ["train_loss"]
        elif self.metrics_to_track is None:
            self.metrics_to_track = ["train_loss_epoch"]
        return super().setup(trainer, pl_module, stage)

    def on_train_epoch_end(self, trainer, pl_module):
        """
        Collect and track metrics at the end of each training epoch.

        The first epoch (``trainer.current_epoch == 0``) is skipped.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        :param pl_module: The model being trained (not used here).
        """
        # Track metrics after the first epoch onwards
        if trainer.current_epoch > 0:
            # Append only the tracked metrics to avoid unnecessary data
            tracked_metrics = {
                k: v
                for k, v in trainer.logged_metrics.items()
                if k in self.metrics_to_track
            }
            # Deep copy so later updates to the logged values do not alter
            # the stored history.
            self._collection.append(copy.deepcopy(tracked_metrics))

    @property
    def metrics(self):
        """
        Aggregate collected metrics over all epochs.

        Only metrics present in every collected epoch are returned. Each
        value is built with :func:`torch.stack`, so it has one entry per
        collected epoch along the first dimension. An empty dictionary is
        returned when nothing has been collected yet.

        :return: A dictionary containing aggregated metric values.
        :rtype: dict
        """
        if not self._collection:
            return {}
        # Get intersection of keys across all collected dictionaries
        common_keys = set(self._collection[0]).intersection(
            *self._collection[1:]
        )
        # Stack the metric values for common keys and return
        return {
            k: torch.stack([dic[k] for dic in self._collection])
            for k in common_keys
            if k in self.metrics_to_track
        }