WeightingInterface
Module for the Weighting Interface.
- class WeightingInterface(update_every_n_epochs=1, aggregator='sum')
Bases: object
Abstract base class for all loss weighting schemes. Every weighting scheme should inherit from this class.
Initialization of the WeightingInterface class.
- Parameters:
update_every_n_epochs (int) – The number of training epochs between weight updates. If set to 1, the weights are updated at every epoch. This parameter is ignored by static weighting schemes. Default is 1.
aggregator (str | Callable) – The aggregation method. Default is 'sum'. One of:
  - 'sum' → torch.sum
  - 'mean' → torch.mean
  - callable → custom aggregation function
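The way a string-or-callable aggregator argument can be resolved into a function may be sketched as follows. This is a torch-free illustration, not the library's implementation: `resolve_aggregator` and `_NAMED_AGGREGATORS` are hypothetical names, and the built-in `sum` and `statistics.fmean` stand in for torch.sum and torch.mean.

```python
from statistics import fmean
from typing import Callable, Union

# Hypothetical stand-ins: the documented class maps 'sum' to torch.sum and
# 'mean' to torch.mean; plain-Python equivalents keep this sketch torch-free.
_NAMED_AGGREGATORS = {"sum": sum, "mean": fmean}


def resolve_aggregator(aggregator: Union[str, Callable]) -> Callable:
    """Return the function for a named aggregator, or the callable itself."""
    if callable(aggregator):
        return aggregator
    if isinstance(aggregator, str) and aggregator in _NAMED_AGGREGATORS:
        return _NAMED_AGGREGATORS[aggregator]
    raise ValueError(f"Unknown aggregator: {aggregator!r}")


values = [1.0, 2.0, 3.0]
total = resolve_aggregator("sum")(values)      # named aggregator
average = resolve_aggregator("mean")(values)   # named aggregator
largest = resolve_aggregator(max)(values)      # custom callable
```

Passing a callable is the escape hatch: any function that reduces the collection of weighted losses to a single value can be used in place of the two named options.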
- abstract weights_update(losses)
Update the weighting scheme based on the given losses.
This method must be implemented by subclasses. Its role is to update the values of the weights. The updated weights will then be used by aggregate() to compute the final aggregated loss.
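As an illustration of what a weights_update() implementation might look like, the sketch below assigns each loss a weight inversely proportional to its magnitude. The class `InverseMagnitudeWeighting`, its `eps` parameter, the `weights` attribute, and the rule itself are all hypothetical; the class does not inherit from the real WeightingInterface, and losses are plain floats rather than tensors.

```python
class InverseMagnitudeWeighting:
    """Hypothetical scheme: weight each loss by the inverse of its magnitude."""

    def __init__(self, eps=1e-8):
        self.eps = eps        # guards against division by zero
        self.weights = {}     # assumption: weights are stored per loss name

    def weights_update(self, losses):
        # Larger losses receive smaller weights.
        inverse = {
            name: 1.0 / (abs(value) + self.eps)
            for name, value in losses.items()
        }
        # Normalize so that the weights sum to one.
        norm = sum(inverse.values())
        self.weights = {name: inv / norm for name, inv in inverse.items()}


scheme = InverseMagnitudeWeighting()
scheme.weights_update({"data": 4.0, "physics": 1.0})
```

With these inputs the smaller "physics" loss ends up with the larger weight, which is the balancing effect this particular rule aims for.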
- final aggregate(losses)
Update the weights (if needed) and aggregate the given losses.
This method first checks whether the loss weights need to be updated based on the current epoch and the update_every_n_epochs setting. If an update is required, it calls weights_update() to refresh the weights. Afterwards, it aggregates the (weighted) losses into a single scalar tensor using the configured aggregator function. This method must not be overridden.
- Parameters:
losses (dict) – The dictionary of losses.
- Returns:
The aggregated loss tensor.
- Return type:
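The update-then-aggregate flow described for aggregate() can be sketched with a small stand-in class. Everything here is an assumption made for illustration: `SketchWeighting`, `DoublePhysics`, the `current_epoch` attribute, and the modulo test for deciding when to update are hypothetical, and floats replace tensors so the sketch runs without torch.

```python
from abc import ABC, abstractmethod


class SketchWeighting(ABC):
    """Hypothetical stand-in mirroring the documented flow."""

    def __init__(self, update_every_n_epochs=1, aggregator=sum):
        self.update_every_n_epochs = update_every_n_epochs
        self.aggregator = aggregator
        self.current_epoch = 0  # assumption: set by the training loop
        self.weights = {}

    @abstractmethod
    def weights_update(self, losses):
        """Refresh self.weights from the given losses."""

    def aggregate(self, losses):
        # Assumed rule: refresh only on epochs matching the update interval.
        if self.current_epoch % self.update_every_n_epochs == 0:
            self.weights_update(losses)
        # Weight each loss (default weight 1.0), then reduce to one value.
        weighted = [
            self.weights.get(name, 1.0) * value
            for name, value in losses.items()
        ]
        return self.aggregator(weighted)


class DoublePhysics(SketchWeighting):
    """Hypothetical scheme that doubles one loss and counts its updates."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.updates = 0

    def weights_update(self, losses):
        self.updates += 1
        self.weights = {
            name: 2.0 if name == "physics" else 1.0 for name in losses
        }


scheme = DoublePhysics(update_every_n_epochs=2)
totals = []
for epoch in range(4):
    scheme.current_epoch = epoch
    totals.append(scheme.aggregate({"data": 1.0, "physics": 3.0}))
```

Over four epochs with an interval of two, weights_update() runs only on epochs 0 and 2, while aggregate() still returns a weighted total on every epoch by reusing the most recent weights.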
- last_saved_weights()
Get the last saved weights.
- Returns:
The last saved weights.
- Return type:
- property solver
The solver employing this weighting scheme.
- Returns:
The solver.
- Return type:
SolverInterface