Source code for pina._src.weighting.linear_weighting
"""Module for the Linear Weighting class."""
from pina._src.weighting.base_weighting import BaseWeighting
from pina._src.core.utils import check_consistency, check_positive_integer
[docs]
class LinearWeighting(BaseWeighting):
    """
    Weighting strategy based on linear interpolation over training epochs.

    This scheme progressively adjusts the weights assigned to each loss term,
    transitioning from a set of initial values to corresponding final values.
    The update follows a linear schedule and is applied at each epoch until the
    specified target epoch is reached. Loss terms that do not appear in the
    provided dictionaries keep a constant weight of ``1``.
    """

    def __init__(self, initial_weights, final_weights, target_epoch):
        """
        Initialization of the :class:`LinearWeighting` class.

        :param dict initial_weights: The mapping of loss identifiers to their
            initial weights at the start of training. Keys represent loss terms
            (e.g., conditions) and values are the corresponding weights. Loss
            terms not explicitly specified default to a weight of ``1``.
        :param dict final_weights: The mapping of loss identifiers to their
            target weights at the specified ``target_epoch``. Keys must match
            those of ``initial_weights``. Loss terms not explicitly specified
            default to a weight of ``1``.
        :param int target_epoch: The epoch at which the weights reach their
            final values. The interpolation progresses linearly from epoch ``0``
            to ``target_epoch``. After ``target_epoch``, the weights remain
            constant at their final values.
        :raises ValueError: If ``initial_weights`` is not a dictionary.
        :raises ValueError: If ``final_weights`` is not a dictionary.
        :raises ValueError: If ``target_epoch`` is not a positive integer.
        :raises ValueError: If the keys of the two dictionaries are not
            consistent.
        """
        # NOTE(review): presumably requests a weight refresh at every epoch
        # and a sum aggregation of the weighted losses -- confirm against
        # BaseWeighting.
        super().__init__(update_every_n_epochs=1, aggregator="sum")

        # Check consistency
        check_consistency([initial_weights, final_weights], dict)
        check_positive_integer(value=target_epoch, strict=True)

        # Check that the keys of the two dictionaries match
        if initial_weights.keys() != final_weights.keys():
            raise ValueError(
                "The keys of the provided dictionaries for initial and final "
                f"weights must match. Got {initial_weights.keys()} as initial "
                f"weight keys and {final_weights.keys()} as final weight keys."
            )

        # Initialization
        self.initial_weights = initial_weights
        self.final_weights = final_weights
        self.target_epoch = target_epoch

    def update_weights(self, losses):
        """
        Compute the linearly interpolated weight of every loss term.

        The current epoch is read from ``self.solver.trainer.current_epoch``
        and clamped to ``target_epoch``, so that the weights stay at their
        final values once the target epoch has been reached.

        :param dict losses: The mapping from loss names to loss tensors. Only
            the keys are used.
        :return: The updated weights, one entry per key of ``losses``.
        :rtype: dict
        """
        # Determine the progress towards the target epoch, in [0, 1]
        epoch = min(self.solver.trainer.current_epoch, self.target_epoch)
        progress = epoch / self.target_epoch

        weights = {}
        for condition in losses:
            # Loss terms missing from the dictionaries fall back to the
            # documented default weight of 1 instead of raising a KeyError.
            initial = self.initial_weights.get(condition, 1)
            final = self.final_weights.get(condition, 1)
            weights[condition] = initial + (final - initial) * progress
        return weights