WeightingInterface

Module for the Weighting Interface.

class WeightingInterface(update_every_n_epochs=1, aggregator='sum')

Bases: object

Abstract base class for all loss weighting schemes. All weighting schemes should inherit from this class.

Initialization of the WeightingInterface class.

Parameters:
  • update_every_n_epochs (int) – The number of training epochs between weight updates. If set to 1, the weights are updated at every epoch. This parameter is ignored by static weighting schemes. Default is 1.

  • aggregator (str | Callable) – The method used to combine the weighted losses into a single scalar. Either:
      ◦ ‘sum’ → torch.sum
      ◦ ‘mean’ → torch.mean
      ◦ a callable → a custom aggregation function
    Default is ‘sum’.
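The mapping from the aggregator argument to a function can be sketched as below. This is a minimal illustration, not the library's implementation: the helper name resolve_aggregator is hypothetical, and raising a ValueError for an unrecognized string is an assumption made for this example.

```python
from typing import Callable, Union

import torch


def resolve_aggregator(aggregator: Union[str, Callable]) -> Callable:
    """Hypothetical helper: turn the `aggregator` argument into a callable."""
    if callable(aggregator):
        # A user-supplied function is used as-is.
        return aggregator
    builtin = {"sum": torch.sum, "mean": torch.mean}
    if aggregator not in builtin:
        # Assumption: unknown strings are rejected rather than silently ignored.
        raise ValueError(f"Unknown aggregator: {aggregator!r}")
    return builtin[aggregator]


stacked = torch.tensor([0.5, 1.5, 2.0])  # e.g. three already-weighted losses
print(resolve_aggregator("sum")(stacked))
print(resolve_aggregator("mean")(stacked))
print(resolve_aggregator(torch.max)(stacked))  # custom callable
```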

abstract weights_update(losses)

Update the weighting scheme based on the given losses.

This method must be implemented by subclasses. Its role is to update the values of the weights. The updated weights will then be used by aggregate() to compute the final aggregated loss.

Parameters:

losses (dict) – The dictionary of losses.

Returns:

The updated weights.

Return type:

dict
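To illustrate what a subclass might put in weights_update(), the sketch below weights each loss by its normalized inverse magnitude, so larger losses receive smaller weights. Both SketchWeighting (a hypothetical stand-in for WeightingInterface, so the block runs without the library installed) and the inverse-magnitude rule are assumptions made for this example, not a scheme shipped with the library.

```python
from abc import ABC, abstractmethod

import torch


class SketchWeighting(ABC):
    """Hypothetical stand-in for WeightingInterface, reduced to weights_update."""

    @abstractmethod
    def weights_update(self, losses: dict) -> dict:
        """Return a dict mapping each loss name to its new weight."""


class InverseMagnitudeWeighting(SketchWeighting):
    """Illustrative scheme: weights proportional to 1 / loss, summing to one."""

    def weights_update(self, losses: dict) -> dict:
        # Detach so the weights act as constants during backpropagation;
        # clamp to avoid division by zero for a vanishing loss.
        inverse = {
            name: 1.0 / loss.detach().clamp_min(1e-12)
            for name, loss in losses.items()
        }
        total = sum(inverse.values())
        return {name: value / total for name, value in inverse.items()}


losses = {"physics": torch.tensor(4.0), "boundary": torch.tensor(1.0)}
weights = InverseMagnitudeWeighting().weights_update(losses)
print({name: round(w.item(), 3) for name, w in weights.items()})
# {'physics': 0.2, 'boundary': 0.8}
```

Declaring weights_update with abc.abstractmethod means the base class itself cannot be instantiated, which mirrors the requirement that every subclass implement this method.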

final aggregate(losses)

Update the weights (if needed) and aggregate the given losses.

This method first checks whether the loss weights need to be updated based on the current epoch and the update_every_n_epochs setting. If an update is required, it calls weights_update() to refresh the weights. Afterwards, it aggregates the (weighted) losses into a single scalar tensor using the configured aggregator function. This method must not be overridden.

Parameters:

losses (dict) – The dictionary of losses.

Returns:

The aggregated loss tensor.

Return type:

torch.Tensor
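The control flow described above can be sketched roughly as follows. This is a hypothetical reduction, not the library's code: the class SketchAggregate, the explicit epoch argument (the documented aggregate() takes only losses), the n_updates counter, the equal-weight placeholder scheme, and stacking the weighted losses into one tensor before applying the aggregator are all assumptions introduced for illustration.

```python
import torch


class SketchAggregate:
    """Hypothetical reduction of the aggregate() control flow."""

    def __init__(self, update_every_n_epochs=1, aggregator=torch.sum):
        self.update_every_n_epochs = update_every_n_epochs
        self.aggregator = aggregator  # a callable, e.g. torch.sum or torch.mean
        self.weights = {}
        self.n_updates = 0  # counts weights_update calls, for illustration only

    def weights_update(self, losses):
        # Placeholder scheme: equal weights that sum to one.
        self.n_updates += 1
        return {name: 1.0 / len(losses) for name in losses}

    def aggregate(self, losses, epoch):
        # Refresh the weights on the first call and whenever the epoch
        # is a multiple of update_every_n_epochs.
        if not self.weights or epoch % self.update_every_n_epochs == 0:
            self.weights = self.weights_update(losses)
        # Weight each loss, stack into one tensor, then reduce to a scalar.
        weighted = torch.stack(
            [self.weights[name] * loss for name, loss in losses.items()]
        )
        return self.aggregator(weighted)


scheme = SketchAggregate(update_every_n_epochs=3)
losses = {"physics": torch.tensor(2.0), "boundary": torch.tensor(4.0)}
for epoch in range(1, 7):
    total = scheme.aggregate(losses, epoch)
print(scheme.n_updates, total.item())  # 3 3.0
```

With update_every_n_epochs=3, the sketch refreshes the weights on the first call (epoch 1) and at epochs 3 and 6, so weights_update runs three times over six epochs; in between, the previously computed weights are reused.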

last_saved_weights()

Get the last saved weights.

Returns:

The last saved weights.

Return type:

dict

property solver

The solver employing this weighting scheme.

Returns:

The solver.

Return type:

SolverInterface