LinearWeighting

Module for the LinearWeighting class.

class LinearWeighting(initial_weights, final_weights, target_epoch)

Bases: BaseWeighting

Weighting strategy based on linear interpolation over training epochs.

This scheme progressively adjusts the weights assigned to each loss term, transitioning from a set of initial values to corresponding final values. The update follows a linear schedule and is applied at each epoch until the specified target epoch is reached.
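A minimal Python sketch of the schedule described above. It assumes that the interpolation fraction at a given epoch is `epoch / target_epoch` and that a term missing from either dictionary uses a weight of 1. The functions `linear_weight` and `linear_weights` are hypothetical helpers for illustration, not part of the library's API.

```python
def linear_weight(initial: float, final: float, epoch: int, target_epoch: int) -> float:
    """Linearly interpolate one weight; hold it at `final` from target_epoch on."""
    if epoch >= target_epoch:
        return final
    fraction = epoch / target_epoch  # 0.0 at epoch 0, grows linearly toward 1.0
    return initial + fraction * (final - initial)


def linear_weights(initial_weights: dict, final_weights: dict,
                   epoch: int, target_epoch: int) -> dict:
    """Apply the schedule to every loss term (assumption: missing terms use 1.0)."""
    keys = set(initial_weights) | set(final_weights)
    return {
        key: linear_weight(initial_weights.get(key, 1.0),
                           final_weights.get(key, 1.0),
                           epoch, target_epoch)
        for key in keys
    }


# Example: swap the emphasis between two loss terms over 10 epochs.
initial = {"physics": 1.0, "data": 10.0}
final = {"physics": 10.0, "data": 1.0}
halfway = linear_weights(initial, final, epoch=5, target_epoch=10)
```

Halfway through the schedule both terms meet in the middle (5.5 each), and from epoch 10 onward the weights stay at their final values.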

Initialization of the LinearWeighting class.

Parameters:
  • initial_weights (dict) – The mapping of loss identifiers to their initial weights at the start of training. Keys represent loss terms (e.g., conditions) and values are the corresponding weights. Loss terms not explicitly specified default to a weight of 1.

  • final_weights (dict) – The mapping of loss identifiers to their target weights at the specified target_epoch. Keys must match those of initial_weights. Loss terms not explicitly specified default to a weight of 1.

  • target_epoch (int) – The epoch at which the weights reach their final values. The interpolation progresses linearly from epoch 0 to target_epoch. After target_epoch, the weights remain constant at their final values.

Raises:
  • ValueError – If initial_weights is not a dictionary.

  • ValueError – If final_weights is not a dictionary.

  • ValueError – If target_epoch is not a positive integer.

  • ValueError – If the keys of the two dictionaries are not consistent.
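The validation rules listed under Raises can be expressed roughly as follows. This is an illustrative sketch, not the library's code: `check_linear_weighting_args` is a hypothetical helper, "consistent keys" is assumed to mean identical key sets, and rejecting `bool` as a `target_epoch` is an additional assumption.

```python
def check_linear_weighting_args(initial_weights, final_weights, target_epoch) -> None:
    """Hypothetical checks mirroring the documented ValueError cases."""
    if not isinstance(initial_weights, dict):
        raise ValueError("initial_weights must be a dictionary")
    if not isinstance(final_weights, dict):
        raise ValueError("final_weights must be a dictionary")
    # bool is a subclass of int in Python, so exclude it explicitly (assumption).
    if (isinstance(target_epoch, bool)
            or not isinstance(target_epoch, int)
            or target_epoch <= 0):
        raise ValueError("target_epoch must be a positive integer")
    # Assumption: consistent keys means both dictionaries name the same loss terms.
    if initial_weights.keys() != final_weights.keys():
        raise ValueError("initial_weights and final_weights must have the same keys")


# Valid arguments pass silently.
check_linear_weighting_args({"physics": 1.0}, {"physics": 5.0}, 100)
```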

update_weights(losses)

Update the weights based on the current losses.

This method defines how the weighting strategy adapts over time. It is responsible for computing and storing updated weights that will be used during aggregation.

Parameters:
  • losses (dict) – The mapping from loss names to loss tensors.

Returns:
  The updated weights.

Return type:
  dict
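To illustrate how `update_weights` feeds into loss aggregation, here is a self-contained sketch. `LinearWeightingSketch`, its `current_epoch` attribute, and its `aggregate` method are hypothetical names. The sketch assumes the training loop advances `current_epoch`, uses plain floats in place of loss tensors, and takes a weighted sum as the aggregation.

```python
class LinearWeightingSketch:
    """Hypothetical stand-in for a linear weighting scheme (not the library's class)."""

    def __init__(self, initial_weights: dict, final_weights: dict, target_epoch: int):
        self.initial_weights = dict(initial_weights)
        self.final_weights = dict(final_weights)
        self.target_epoch = target_epoch
        self.current_epoch = 0  # assumption: set by the training loop
        self.weights = {}       # last computed weights, reused during aggregation

    def update_weights(self, losses: dict) -> dict:
        # The linear schedule depends only on the epoch; `losses` supplies the keys.
        t = min(self.current_epoch, self.target_epoch) / self.target_epoch
        self.weights = {}
        for name in losses:
            start = self.initial_weights.get(name, 1.0)  # unspecified terms default to 1
            end = self.final_weights.get(name, 1.0)
            self.weights[name] = start + t * (end - start)
        return self.weights

    def aggregate(self, losses: dict) -> float:
        # Hypothetical aggregation: weighted sum of the individual losses.
        weights = self.update_weights(losses)
        return sum(weights[name] * value for name, value in losses.items())


weighting = LinearWeightingSketch({"physics": 0.0}, {"physics": 1.0}, target_epoch=4)
losses = {"physics": 2.0, "data": 3.0}
weighting.current_epoch = 2
total_mid = weighting.aggregate(losses)
```

At epoch 2 of 4, the "physics" term has reached half of its final weight (0.5), while "data" keeps the default weight of 1, so the aggregated loss is 0.5 · 2.0 + 1.0 · 3.0 = 4.0. Because a linear schedule depends only on the epoch, the loss values here determine which terms receive weights but not the weights themselves.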