Source code for pina._src.loss.base_dual_loss
"""Module for the BaseDualLoss class."""
import torch
from pina._src.loss.dual_loss_interface import DualLossInterface
[docs]
class BaseDualLoss(DualLossInterface):
    """
    Common base for losses that compare an input tensor against a target
    tensor.

    It validates the requested reduction mode at construction time and
    provides :meth:`_reduction` to aggregate pointwise loss values
    accordingly. Concrete losses are expected to subclass it and implement
    the abstract methods declared by the interface. It is not intended to
    be instantiated directly.
    """

    # Reduction name -> callable applied to the pointwise loss tensor.
    # "sum" and "mean" aggregate over the last dimension and keep it as a
    # size-1 axis; "none" passes the tensor through unchanged.
    _REDUCTION_METHOD = {
        "sum": lambda tensor: torch.sum(tensor, dim=-1, keepdim=True),
        "mean": lambda tensor: torch.mean(tensor, dim=-1, keepdim=True),
        "none": lambda tensor: tensor,
    }

    def __init__(self, reduction="mean"):
        """
        Set up the loss with the chosen reduction mode.

        :param str reduction: How pointwise loss values are aggregated.
            Must be one of ``"none"`` (no aggregation), ``"mean"``
            (average) or ``"sum"`` (total). Defaults to ``"mean"``.
        :raises ValueError: If ``reduction`` is not a supported mode.
        """
        # Reject unsupported modes before the parent stores the value.
        if reduction not in self._REDUCTION_METHOD:
            available = list(self._REDUCTION_METHOD)
            raise ValueError(
                "Invalid reduction method. Available options: "
                f"{available}. Got {reduction}."
            )

        # The legacy size_average/reduce arguments are explicitly disabled.
        super().__init__(size_average=None, reduce=None, reduction=reduction)

    def _reduction(self, loss):
        """
        Aggregate pointwise loss values using the configured reduction.

        :param torch.Tensor loss: Pointwise loss values.
        :return: The loss after applying the reduction stored in
            ``self.reduction``.
        :rtype: torch.Tensor
        """
        reduce_fn = self._REDUCTION_METHOD[self.reduction]
        return reduce_fn(loss)