Metric Tracker

Module for the Metric Tracker.

class MetricTracker(metrics_to_track=None)

Bases: Callback

Callback for collecting selected metrics logged during training.

Initialize a MetricTracker instance.

Parameters:

metrics_to_track (str | list[str] | None) – The name of a single metric, or a list of metric names, to collect. If None, it defaults to ["train_loss"] when no batch size is available, and to ["train_loss_epoch"] otherwise. Default is None.

Raises:

ValueError – If any of the provided metric names are not strings.
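The default and validation rules above can be summarized in a few lines. The following is a minimal sketch, not the library's code: resolve_metrics_to_track is a hypothetical helper, the batch size is passed in explicitly as a stand-in for trainer.batch_size, and treating a single string as a one-element list is an assumption based on the documented str | list[str] type.

```python
def resolve_metrics_to_track(metrics_to_track=None, batch_size=None):
    """Hypothetical helper mirroring the documented defaults and validation."""
    if metrics_to_track is None:
        # Documented default: depends on whether a batch size is available.
        return ["train_loss"] if batch_size is None else ["train_loss_epoch"]
    # Assumption: a single name is shorthand for a one-element list.
    if isinstance(metrics_to_track, str):
        metrics_to_track = [metrics_to_track]
    names = list(metrics_to_track)
    # Documented behaviour: every metric name must be a string.
    if not all(isinstance(name, str) for name in names):
        raise ValueError("All metric names must be strings.")
    return names
```

For example, resolve_metrics_to_track(batch_size=32) yields ["train_loss_epoch"], while resolve_metrics_to_track(["train_loss", 3]) raises ValueError.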

setup(trainer, pl_module, stage)

Configure the metrics to track before execution starts.

When a batch size is provided (i.e. trainer.batch_size is not None), metric names are expanded to match Lightning’s logging convention: for each metric m, both m_step and m_epoch are tracked. For example, "train_loss" becomes ["train_loss_step", "train_loss_epoch"].

Parameters:
  • trainer (Trainer) – The trainer instance managing the execution.

  • pl_module (BaseSolver) – The solver module being executed.

  • stage (str) – Current execution stage.
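The name expansion performed by setup can be illustrated with a small standalone function. This is a sketch under stated assumptions: expand_metric_names is hypothetical, and it relies only on Lightning's behaviour of logging a metric recorded with both on_step=True and on_epoch=True under the keys "<name>_step" and "<name>_epoch".

```python
def expand_metric_names(metrics, batch_size):
    """Hypothetical helper: expand names to Lightning's step/epoch variants."""
    if batch_size is None:
        # No batch size: names are tracked exactly as given.
        return list(metrics)
    expanded = []
    for name in metrics:
        # A metric logged per step and per epoch appears under two keys.
        expanded.extend([f"{name}_step", f"{name}_epoch"])
    return expanded
```

With a batch size set, expand_metric_names(["train_loss"], 16) returns ["train_loss_step", "train_loss_epoch"], matching the example in the description above.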

on_train_epoch_end(trainer, __)

Store the selected logged metrics at the end of each training epoch.

Parameters:
  • trainer (Trainer) – The trainer instance managing the execution.

  • __ – Placeholder for the LightningModule argument that Lightning passes to this hook; not used.
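The per-epoch bookkeeping can be sketched without a trainer. In this minimal sketch, EpochMetricCollector is a hypothetical class, and the plain dictionary passed to on_train_epoch_end stands in for Lightning's trainer.logged_metrics; deep-copying each snapshot is a defensive assumption, not a documented detail of MetricTracker.

```python
import copy


class EpochMetricCollector:
    """Hypothetical stand-in for the epoch-end bookkeeping of MetricTracker."""

    def __init__(self, metrics_to_track):
        self.metrics_to_track = set(metrics_to_track)
        self.collection = []  # one dict per completed training epoch

    def on_train_epoch_end(self, logged_metrics):
        # Keep only the tracked names and ignore everything else that was logged.
        snapshot = {
            name: value
            for name, value in logged_metrics.items()
            if name in self.metrics_to_track
        }
        # Copy so that later in-place updates of logged values cannot leak in.
        self.collection.append(copy.deepcopy(snapshot))
```

Filtering at collection time keeps the stored history small when many other quantities (learning rate, validation metrics) are logged alongside the tracked ones.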

property metrics

Return the collected metrics stacked over the tracked epochs.

Returns:

A dictionary mapping each tracked metric name to a tensor of its values, with one entry per collected epoch. Returns an empty dictionary if no metrics have been collected.

Return type:

dict[str, torch.Tensor]
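Turning the per-epoch snapshots into per-metric histories can be sketched as follows. This is an illustration only: stack_metrics is hypothetical, it returns plain Python lists where the real property returns torch.Tensor values (torch.stack over the per-epoch scalar tensors would be the natural equivalent), and keeping only names present in every snapshot is an assumption about how missing entries are handled.

```python
def stack_metrics(collection):
    """Hypothetical helper: turn per-epoch snapshots into per-metric series."""
    if not collection:
        # Documented behaviour: nothing collected -> empty dictionary.
        return {}
    # Assumption: only names present in every snapshot form a complete series.
    common = set(collection[0]).intersection(*collection[1:])
    return {
        name: [snapshot[name] for snapshot in collection]
        for name in sorted(common)
    }
```

For instance, stack_metrics([{"a": 1.0, "b": 2.0}, {"a": 0.5}]) keeps only "a", giving {"a": [1.0, 0.5]}, because "b" is missing from the second epoch.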