# Source code for pina.loss.scalar_weighting

"""Module for the Scalar Weighting."""

from .weighting_interface import WeightingInterface
from ..utils import check_consistency


class _NoWeighting(WeightingInterface):
    """
    Identity weighting: every loss term contributes with unit weight.
    """

    def aggregate(self, losses):
        """
        Combine the per-condition losses by plain summation.

        :param dict losses: Mapping from condition name to its loss value.
        :return: The sum of all loss values (``0`` if ``losses`` is empty).
        :rtype: torch.Tensor
        """
        # Conditions are ignored; only the loss values are accumulated.
        return sum(loss for _, loss in losses.items())


class ScalarWeighting(WeightingInterface):
    """
    Weighting scheme that assigns a scalar weight to each loss term.
    """

    def __init__(self, weights):
        """
        Initialization of the :class:`ScalarWeighting` class.

        :param weights: The weights to be assigned to each loss term.
            If a single scalar value is provided, it is assigned to all loss
            terms. If a dictionary is provided, the keys are the conditions
            and the values are the weights; any condition not present in the
            dictionary is given a default weight of ``1``.
        :type weights: float | int | dict
        """
        super().__init__()
        check_consistency([weights], (float, dict, int))
        if isinstance(weights, (float, int)):
            # A scalar applies uniformly: no per-condition overrides.
            self.default_value_weights = weights
            self.weights = {}
        else:
            # Conditions missing from the dictionary fall back to weight 1.
            self.default_value_weights = 1
            # Shallow copy so that later mutation of the caller's dictionary
            # does not silently alter the weighting of this instance.
            self.weights = dict(weights)

    def aggregate(self, losses):
        """
        Aggregate the losses.

        Each loss is multiplied by the weight registered for its condition
        (or by ``self.default_value_weights`` when the condition has no
        explicit weight), and the weighted terms are summed.

        :param dict losses: The dictionary of losses, keyed by condition.
        :return: The aggregated losses (``0`` if ``losses`` is empty).
        :rtype: torch.Tensor
        """
        return sum(
            self.weights.get(condition, self.default_value_weights) * loss
            for condition, loss in losses.items()
        )