Processing callbacks
- class MetricTracker
Bases: Callback
PINA implementation of a Lightning callback for metric tracking.
This class records the metrics logged during training, storing a snapshot of them after each training epoch.
- Variables:
_collection – A list that stores the metrics collected after each training epoch.
- Parameters:
trainer (pytorch_lightning.Trainer) – The trainer object managing the training process.
- Returns:
A dictionary containing aggregated metric values.
- Return type:
dict
Example
>>> tracker = MetricTracker()
>>> trainer = Trainer(solver, callbacks=[tracker])
>>> # ... Perform training ...
>>> metrics = tracker.metrics
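The aggregation into a dictionary of metric values can be pictured with plain Python. The sketch below is an assumption about the behaviour, not PINA's implementation: each epoch's logged metrics are modeled as a dictionary of floats, only the keys present in every epoch are kept, and each key's values are gathered into a per-epoch list (plain lists stand in for whatever container the library actually uses). The helper name `aggregate_metrics` and the numbers are hypothetical.

```python
from typing import Dict, List


def aggregate_metrics(collection: List[Dict[str, float]]) -> Dict[str, List[float]]:
    """Hypothetical helper: merge per-epoch metric dicts into per-key histories.

    Only keys logged in every epoch are kept, so each history has exactly
    one entry per recorded epoch.
    """
    if not collection:
        return {}
    # Keys shared by all epochs (iterating a dict yields its keys).
    common_keys = set(collection[0]).intersection(*collection[1:])
    # Sort the keys so the output order is deterministic.
    return {key: [epoch[key] for epoch in collection] for key in sorted(common_keys)}


# Simulated contents of `_collection` after three epochs (made-up numbers).
history = [
    {"mean": 0.9, "gamma1": 0.5},
    {"mean": 0.6, "gamma1": 0.3, "lr": 1e-3},
    {"mean": 0.4, "gamma1": 0.2},
]
print(aggregate_metrics(history))
# {'gamma1': [0.5, 0.3, 0.2], 'mean': [0.9, 0.6, 0.4]}
```

Here "lr" is dropped because it was logged in only one epoch; keeping only common keys guarantees that every history has the same length.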
- class PINAProgressBar(metrics='mean', **kwargs)
Bases: TQDMProgressBar
PINA implementation of a Lightning callback for enriching the progress bar.
This class displays only the selected logged metrics during training.
- Parameters:
metrics (str | list(str) | tuple(str)) – Logged metrics to display during training. They should be a subset of the condition keys defined in pina.condition.Condition.
- Keyword Arguments:
Additional keyword arguments configure the progress bar and can be chosen from the pytorch-lightning TQDMProgressBar API.
Example
>>> pbar = PINAProgressBar(['mean'])
>>> trainer = Trainer(solver, callbacks=[pbar])
>>> # ... Perform training ...
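The metrics parameter accepts a single string or a list or tuple of strings. The sketch below shows one way such an argument can be normalized to a list; it is an illustration under stated assumptions, not PINA's code. The helper `normalize_metrics` is hypothetical, and raising `TypeError` for other inputs is an assumed policy.

```python
from typing import List, Sequence, Union


def normalize_metrics(metrics: Union[str, Sequence[str]]) -> List[str]:
    """Hypothetical helper: turn a str, list(str) or tuple(str) into a list of str."""
    # A bare string names a single metric.
    if isinstance(metrics, str):
        return [metrics]
    # Only lists and tuples are accepted as collections of names.
    if not isinstance(metrics, (list, tuple)):
        raise TypeError("metrics must be a str, list(str) or tuple(str)")
    # Every entry must itself be a string.
    if not all(isinstance(name, str) for name in metrics):
        raise TypeError("every metric name must be a str")
    return list(metrics)


print(normalize_metrics("mean"))              # ['mean']
print(normalize_metrics(("mean", "gamma1")))  # ['mean', 'gamma1']
```

Normalizing up front lets the rest of the callback treat the selection uniformly as an ordered list of names.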
- get_metrics(trainer, pl_module)
Combines the progress bar metrics collected from the trainer with the standard metrics from get_standard_metrics. Implement this method to override the items displayed in the progress bar. The progress bar metrics are selected and ordered according to the metrics parameter.
Here is an example of how to override the defaults:
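from pytorch_lightning.callbacks import TQDMProgressBar


class NoVersionProgressBar(TQDMProgressBar):  # hypothetical subclass name
    def get_metrics(self, trainer, model):
        # don't show the version number
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)
        return items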
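The selection step of get_metrics can be illustrated without Lightning. In this sketch, which is an assumption about the behaviour rather than PINA's code, plain dictionaries stand in for the standard metrics (as returned by get_standard_metrics) and for the trainer's progress bar metrics; the requested entries are kept in the order given by the metrics parameter and merged after the standard ones. Requested names that were never logged are silently skipped here, which is an assumed policy. The function name `select_progress_bar_metrics` and all values are hypothetical.

```python
from typing import Dict, List, Union

Value = Union[int, float, str]


def select_progress_bar_metrics(
    standard: Dict[str, Value],
    logged: Dict[str, Value],
    metrics: List[str],
) -> Dict[str, Value]:
    """Hypothetical helper: keep only the requested logged metrics, in order."""
    # Keep requested names that were actually logged, preserving the order of `metrics`.
    selected = {name: logged[name] for name in metrics if name in logged}
    # Standard entries (e.g. the version number) come first, selected metrics after.
    return {**standard, **selected}


standard = {"v_num": 0}
logged = {"gamma1": 0.31, "mean": 0.12, "D": 0.05}
print(select_progress_bar_metrics(standard, logged, ["mean", "gamma1"]))
# {'v_num': 0, 'mean': 0.12, 'gamma1': 0.31}
```

Because Python dictionaries preserve insertion order, building the selection by iterating over the metrics list is enough to control the display order.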